1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Dialect/Traits.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Transforms/InliningUtils.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/raw_ostream.h"
23 
24 using namespace mlir;
25 using namespace mlir::shape;
26 
27 namespace {
28 #include "ShapeCanonicalization.inc"
29 }
30 
31 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
32   return RankedTensorType::get({rank}, IndexType::get(ctx));
33 }
34 
35 bool shape::isExtentTensorType(Type type) {
36   auto ranked = type.dyn_cast<RankedTensorType>();
37   return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
38 }
39 
40 LogicalResult shape::getShapeVec(Value input,
41                                  SmallVectorImpl<int64_t> &shapeValues) {
42   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
43     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
44     if (!type.hasRank())
45       return failure();
46     shapeValues = llvm::to_vector<6>(type.getShape());
47     return success();
48   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
49     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
50     return success();
51   } else if (auto inputOp = input.getDefiningOp<ConstantOp>()) {
52     shapeValues = llvm::to_vector<6>(
53         inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>());
54     return success();
55   } else {
56     return failure();
57   }
58 }
59 
60 static bool isErrorPropagationPossible(TypeRange operandTypes) {
61   return llvm::any_of(operandTypes, [](Type ty) {
62     return ty.isa<SizeType, ShapeType, ValueShapeType>();
63   });
64 }
65 
66 static LogicalResult verifySizeOrIndexOp(Operation *op) {
67   assert(op != nullptr && op->getNumResults() == 1);
68   Type resultTy = op->getResultTypes().front();
69   if (isErrorPropagationPossible(op->getOperandTypes())) {
70     if (!resultTy.isa<SizeType>())
71       return op->emitOpError()
72              << "if at least one of the operands can hold error values then "
73                 "the result must be of type `size` to propagate them";
74   }
75   return success();
76 }
77 
78 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
79   assert(op != nullptr && op->getNumResults() == 1);
80   Type resultTy = op->getResultTypes().front();
81   if (isErrorPropagationPossible(op->getOperandTypes())) {
82     if (!resultTy.isa<ShapeType>())
83       return op->emitOpError()
84              << "if at least one of the operands can hold error values then "
85                 "the result must be of type `shape` to propagate them";
86   }
87   return success();
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // InlinerInterface
92 //===----------------------------------------------------------------------===//
93 
94 namespace {
/// This class defines the interface for inlining shape dialect ops.
/// The shape dialect imposes no dialect-specific restrictions on inlining, so
/// both hooks unconditionally allow it.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  // Always legal for shape dialect regions.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect. Always legal for shape ops.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
114 } // namespace
115 
void ShapeDialect::initialize() {
  // Register all shape ops from the generated op list
  // (ShapeOps.cpp.inc).
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
128 
/// Materializes a constant op for the given attribute/type pair; used when
/// folding produces an attribute that must be turned back into IR.
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  // Shapes and extent tensors are backed by dense index elements.
  if (type.isa<ShapeType>() || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  // Fall back to a standard constant when one can represent the value.
  if (ConstantOp::isBuildableWith(value, type))
    return builder.create<ConstantOp>(loc, type, value);
  // Null signals that no constant could be materialized.
  return nullptr;
}
143 
144 /// Parse a type registered to this dialect.
145 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
146   StringRef keyword;
147   if (parser.parseKeyword(&keyword))
148     return Type();
149 
150   if (keyword == "shape")
151     return ShapeType::get(getContext());
152   if (keyword == "size")
153     return SizeType::get(getContext());
154   if (keyword == "value_shape")
155     return ValueShapeType::get(getContext());
156   if (keyword == "witness")
157     return WitnessType::get(getContext());
158 
159   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
160   return Type();
161 }
162 
163 /// Print a type registered to this dialect.
164 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
165   TypeSwitch<Type>(type)
166       .Case<ShapeType>([&](Type) { os << "shape"; })
167       .Case<SizeType>([&](Type) { os << "size"; })
168       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
169       .Case<WitnessType>([&](Type) { os << "witness"; })
170       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
171 }
172 
LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  // Verify shape.lib attribute.
  if (attribute.first == "shape.lib") {
    // The attribute is resolved by symbol lookup, so the carrying op must
    // define a symbol table.
    if (!op->hasTrait<OpTrait::SymbolTable>())
      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");

    // Case 1: a single symbol reference to a shape function library.
    if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) {
      auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
      if (!symbol)
        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)
                 ? success()
                 : op->emitError()
                       << symbolRef << " required to be shape function library";
    }

    // Case 2: an array of symbol references.
    if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) {
      // Verify all entries are function libraries and mappings in libraries
      // refer to unique ops.
      DenseSet<Identifier> key;
      for (auto it : arr) {
        if (!it.isa<SymbolRefAttr>())
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");

        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
            SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
        if (!shapeFnLib)
          return op->emitError()
                 << it << " does not refer to FunctionLibraryOp";
        // Each op name may be mapped to a shape function at most once across
        // all referenced libraries.
        for (auto mapping : shapeFnLib.mapping()) {
          if (!key.insert(mapping.first).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.first << "`";
          }
        }
      }
      return success();
    }

    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  }
  // All other attributes are accepted without further checks.
  return success();
}
222 
223 //===----------------------------------------------------------------------===//
224 // AnyOp
225 //===----------------------------------------------------------------------===//
226 
227 // TODO: Canonicalization should be implemented for shapes that can be
228 // determined through mixtures of the known dimensions of the inputs.
229 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
230   // Only the last operand is checked because AnyOp is commutative.
231   if (operands.back())
232     return operands.back();
233 
234   return nullptr;
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // AssumingOp
239 //===----------------------------------------------------------------------===//
240 
/// Custom parser for AssumingOp. Expected form:
///   shape.assuming %witness (-> result-types)? region attr-dict
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // The single operand is the witness; its type is always !shape.witness and
  // is therefore not spelled in the assembly.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
267 
/// Custom printer for AssumingOp; inverse of parseAssumingOp above.
static void print(OpAsmPrinter &p, AssumingOp op) {
  bool yieldsResults = !op.results().empty();

  p << AssumingOp::getOperationName() << " " << op.witness();
  if (yieldsResults) {
    p << " -> (" << op.getResultTypes() << ")";
  }
  // The terminator is only printed when it yields values; an operand-less
  // yield is elided and re-created by the parser.
  p.printRegion(op.doRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict(op->getAttrs());
}
280 
281 namespace {
282 // Removes AssumingOp with a passing witness and inlines the region.
283 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
284   using OpRewritePattern<AssumingOp>::OpRewritePattern;
285 
286   LogicalResult matchAndRewrite(AssumingOp op,
287                                 PatternRewriter &rewriter) const override {
288     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
289     if (!witness || !witness.passingAttr())
290       return failure();
291 
292     AssumingOp::inlineRegionIntoParent(op, rewriter);
293     return success();
294   }
295 };
296 
// Drops AssumingOp results that have no uses by rebuilding the op with a
// trimmed yield and result list.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.witness());
    newOp.doRegion().takeBody(op.doRegion());

    // Use the new results to replace the previously used ones.
    // Unused old results are replaced with a null value, which is fine
    // because they have no uses to rewire.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
342 } // namespace
343 
void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  // Both patterns are defined in the anonymous namespace above.
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}
348 
349 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
350 void AssumingOp::getSuccessorRegions(
351     Optional<unsigned> index, ArrayRef<Attribute> operands,
352     SmallVectorImpl<RegionSuccessor> &regions) {
353   // AssumingOp has unconditional control flow into the region and back to the
354   // parent, so return the correct RegionSuccessor purely based on the index
355   // being None or 0.
356   if (index.hasValue()) {
357     regions.push_back(RegionSuccessor(getResults()));
358     return;
359   }
360 
361   regions.push_back(RegionSuccessor(&doRegion()));
362 }
363 
/// Inlines the single-block body of `op` into the parent block, replacing the
/// op's results with the yield's operands and erasing both the AssumingOp and
/// its terminator. Uses the rewriter's current insertion point as the splice
/// position — presumably callers position it at `op`; verify at call sites.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  // Split the surrounding block at the insertion point so the body can be
  // spliced between the two halves.
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
383 
384 void AssumingOp::build(
385     OpBuilder &builder, OperationState &result, Value witness,
386     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
387 
388   result.addOperands(witness);
389   Region *bodyRegion = result.addRegion();
390   bodyRegion->push_back(new Block);
391   Block &bodyBlock = bodyRegion->front();
392 
393   // Build body.
394   OpBuilder::InsertionGuard guard(builder);
395   builder.setInsertionPointToStart(&bodyBlock);
396   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
397   builder.create<AssumingYieldOp>(result.location, yieldValues);
398 
399   SmallVector<Type, 2> assumingTypes;
400   for (Value v : yieldValues)
401     assumingTypes.push_back(v.getType());
402   result.addTypes(assumingTypes);
403 }
404 
405 //===----------------------------------------------------------------------===//
406 // AssumingAllOp
407 //===----------------------------------------------------------------------===//
408 
409 namespace {
// Rewrites an assuming_all whose inputs are all cstr_eq witnesses into one
// cstr_eq over the union of their shapes. This is only applied when each
// cstr_eq shares at least one shape with the ones collected so far, so
// equality of the whole union is transitively implied.
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.inputs()) {
      // Every input witness must be produced by a cstr_eq.
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      // True when this cstr_eq has no shape in common with those seen so far.
      bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};
432 } // namespace
433 
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // AssumingAllOneOp is a generated pattern from ShapeCanonicalization.inc.
  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization>(context);
}
438 
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    // NOTE: this erases the handled operand from the op even when the fold
    // ultimately returns null, i.e. folding mutates the op in place here.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
459 
460 static LogicalResult verify(AssumingAllOp op) {
461   // Ensure that AssumingAllOp contains at least one operand
462   if (op.getNumOperands() == 0)
463     return op.emitOpError("no operands specified");
464 
465   return success();
466 }
467 
/// Convenience builder: the result of assuming_all is always a single
/// !shape.witness value.
void AssumingAllOp::build(OpBuilder &b, OperationState &state,
                          ValueRange inputs) {
  build(b, state, b.getType<WitnessType>(), inputs);
}
472 
473 //===----------------------------------------------------------------------===//
474 // BroadcastOp
475 //===----------------------------------------------------------------------===//
476 
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  // A broadcast of a single shape is the identity on that shape, provided no
  // type change is required.
  if (shapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (shapes().front().getType() != getType())
      return nullptr;
    return shapes().front();
  }

  // TODO: Support folding with more than 2 input shapes
  if (shapes().size() > 2)
    return nullptr;

  // Both operand shapes must be statically known to fold.
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
505 
// Delegate to the shared verifier: the result must be of `shape` type when
// any operand can carry an error value.
static LogicalResult verify(BroadcastOp op) {
  return verifyShapeOrExtentTensorOp(op);
}
509 
510 namespace {
511 template <typename OpTy>
512 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
513   using OpRewritePattern<OpTy>::OpRewritePattern;
514 
515   LogicalResult matchAndRewrite(OpTy op,
516                                 PatternRewriter &rewriter) const override {
517     // Find unique operands.
518     SmallVector<Value, 2> unique;
519     for (Value v : op.getOperands()) {
520       if (!llvm::is_contained(unique, v))
521         unique.push_back(v);
522     }
523 
524     // Reduce op to equivalent with unique operands.
525     if (unique.size() < op.getNumOperands()) {
526       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
527                                         op->getAttrs());
528       return success();
529     }
530 
531     return failure();
532   }
533 };
534 
535 template <typename OpTy>
536 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
537   using OpRewritePattern<OpTy>::OpRewritePattern;
538 
539   LogicalResult matchAndRewrite(OpTy op,
540                                 PatternRewriter &rewriter) const override {
541     auto isPotentiallyNonEmptyShape = [](Value shape) {
542       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
543         if (extentTensorTy.getDimSize(0) == 0)
544           return false;
545       }
546       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
547         if (constShape.shape().empty())
548           return false;
549       }
550       return true;
551     };
552     auto newOperands = llvm::to_vector<8>(
553         llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
554 
555     // Reduce op to equivalent without empty shape operands.
556     if (newOperands.size() < op.getNumOperands()) {
557       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
558                                         op->getAttrs());
559       return success();
560     }
561 
562     return failure();
563   }
564 };
565 
566 struct BroadcastForwardSingleOperandPattern
567     : public OpRewritePattern<BroadcastOp> {
568   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
569 
570   LogicalResult matchAndRewrite(BroadcastOp op,
571                                 PatternRewriter &rewriter) const override {
572     if (op.getNumOperands() != 1)
573       return failure();
574     Value replacement = op.shapes().front();
575 
576     // Insert cast if needed.
577     if (replacement.getType() != op.getType()) {
578       auto loc = op.getLoc();
579       if (op.getType().isa<ShapeType>()) {
580         replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
581       } else {
582         assert(!op.getType().isa<ShapeType>() &&
583                !replacement.getType().isa<ShapeType>() &&
584                "expect extent tensor cast");
585         replacement =
586             rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
587       }
588     }
589 
590     rewriter.replaceOp(op, replacement);
591     return success();
592   }
593 };
594 
// Folds all constant-shape operands of a broadcast into one const_shape
// operand (appended last), leaving non-constant operands untouched.
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.shapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        // Accumulate into the running broadcast; on incompatibility, keep the
        // operand as-is instead of failing the whole pattern.
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.shape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};
632 
633 template <typename OpTy>
634 struct CanonicalizeCastExtentTensorOperandsPattern
635     : public OpRewritePattern<OpTy> {
636   using OpRewritePattern<OpTy>::OpRewritePattern;
637 
638   LogicalResult matchAndRewrite(OpTy op,
639                                 PatternRewriter &rewriter) const override {
640     // Canonicalize operands.
641     bool anyChange = false;
642     auto canonicalizeOperand = [&](Value operand) {
643       if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
644         // Only eliminate the cast if it holds no shape information.
645         bool isInformationLoosingCast =
646             castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
647         if (isInformationLoosingCast) {
648           anyChange = true;
649           return castOp.source();
650         }
651       }
652       return operand;
653     };
654     auto newOperands = llvm::to_vector<8>(
655         llvm::map_range(op.getOperands(), canonicalizeOperand));
656 
657     // Rewrite op if any change required.
658     if (!anyChange)
659       return failure();
660     rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
661     return success();
662   }
663 };
664 
665 struct BroadcastConcretizeResultTypePattern
666     : public OpRewritePattern<BroadcastOp> {
667   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
668 
669   LogicalResult matchAndRewrite(BroadcastOp op,
670                                 PatternRewriter &rewriter) const override {
671     // Only concretize dynamic extent tensor result types.
672     auto resultTy = op.getType().dyn_cast<RankedTensorType>();
673     if (!resultTy || !resultTy.isDynamicDim(0))
674       return failure();
675 
676     // Infer resulting shape rank if possible.
677     int64_t maxRank = 0;
678     for (Value shape : op.shapes()) {
679       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
680         // Cannot infer resulting shape rank if any operand is dynamically
681         // ranked.
682         if (extentTensorTy.isDynamicDim(0))
683           return failure();
684         maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
685       }
686     }
687 
688     auto newOp = rewriter.create<BroadcastOp>(
689         op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes());
690     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
691     return success();
692   }
693 };
694 } // namespace
695 
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  // Patterns are listed alphabetically; all are defined in the anonymous
  // namespace above.
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}
705 
706 //===----------------------------------------------------------------------===//
707 // ConcatOp
708 //===----------------------------------------------------------------------===//
709 
710 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
711   if (!operands[0] || !operands[1])
712     return nullptr;
713   auto lhsShape = llvm::to_vector<6>(
714       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
715   auto rhsShape = llvm::to_vector<6>(
716       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
717   SmallVector<int64_t, 6> resultShape;
718   resultShape.append(lhsShape.begin(), lhsShape.end());
719   resultShape.append(rhsShape.begin(), rhsShape.end());
720   Builder builder(getContext());
721   return builder.getIndexTensorAttr(resultShape);
722 }
723 
724 //===----------------------------------------------------------------------===//
725 // ConstShapeOp
726 //===----------------------------------------------------------------------===//
727 
/// Prints const_shape as: shape.const_shape attr-dict [extents] : type
static void print(OpAsmPrinter &p, ConstShapeOp &op) {
  p << "shape.const_shape ";
  // The `shape` attribute is rendered as the bracketed extent list below, so
  // elide it from the generic attribute dictionary.
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(op.shape().getValues<int64_t>(), p,
                  [&](int64_t i) { p << i; });
  p << "] : ";
  p.printType(op.getType());
}
737 
/// Parses const_shape; inverse of the printer above.
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // All extents must be integer attributes.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  // Store the extents canonically as a dense index tensor attribute.
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
767 
// A const_shape is already constant; fold to its own `shape` attribute.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
769 
void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  // TensorCastConstShape is a generated pattern from
  // ShapeCanonicalization.inc.
  patterns.add<TensorCastConstShape>(context);
}
774 
775 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
776     MLIRContext *context, Optional<Location> location, ValueRange operands,
777     DictionaryAttr attributes, RegionRange regions,
778     SmallVectorImpl<Type> &inferredReturnTypes) {
779   Builder b(context);
780   auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
781   if (!shape)
782     return emitOptionalError(location, "missing shape attribute");
783   inferredReturnTypes.assign({RankedTensorType::get(
784       {static_cast<int64_t>(shape.size())}, b.getIndexType())});
785   return success();
786 }
787 
788 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
789                                                         TypeRange r) {
790   if (l.size() != 1 || r.size() != 1)
791     return false;
792 
793   Type lhs = l.front();
794   Type rhs = r.front();
795 
796   if (lhs == rhs)
797     return true;
798 
799   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
800     // Shape type is compatible with all other valid return types.
801     return true;
802 
803   return succeeded(verifyCompatibleShapes(lhs, rhs));
804 }
805 
806 //===----------------------------------------------------------------------===//
807 // CstrBroadcastableOp
808 //===----------------------------------------------------------------------===//
809 
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  // Registered patterns: drop extent-tensor casts on operands, rewrite to
  // cstr_eq when operands are known equal, and remove duplicate or
  // statically-empty shape operands.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}
820 
// Return true if at most one of the given attributes represents a non-scalar
// broadcast (i.e. a shape with at least one extent). A null attribute means
// the shape is unknown and is conservatively treated as non-scalar.
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      // A second (potentially) non-scalar shape means broadcasting may be
      // required.
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}
834 
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  // Next, fold to a passing witness if every operand folded to a constant
  // shape and those shapes are statically known to broadcast together.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : shapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}
869 
static LogicalResult verify(CstrBroadcastableOp op) {
  // Broadcastability is only a meaningful constraint between two or more
  // shapes, so require at least two operands.
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");
  return success();
}
876 
877 //===----------------------------------------------------------------------===//
878 // CstrEqOp
879 //===----------------------------------------------------------------------===//
880 
void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If all inputs are the same SSA value, replace with a passing witness.
  patterns.add<CstrEqEqOps>(context);
}
886 
887 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
888   if (llvm::all_of(operands,
889                    [&](Attribute a) { return a && a == operands[0]; }))
890     return BoolAttr::get(getContext(), true);
891 
892   // Because a failing witness result here represents an eventual assertion
893   // failure, we do not try to replace it with a constant witness. Similarly, we
894   // cannot if there are any non-const inputs.
895   return nullptr;
896 }
897 
898 //===----------------------------------------------------------------------===//
899 // ConstSizeOp
900 //===----------------------------------------------------------------------===//
901 
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  // Convenience builder: wrap the raw integer in an index attribute.
  build(builder, result, builder.getIndexAttr(value));
}
906 
// A constant size folds unconditionally to its "value" attribute.
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
908 
909 void ConstSizeOp::getAsmResultNames(
910     llvm::function_ref<void(Value, StringRef)> setNameFn) {
911   SmallString<4> buffer;
912   llvm::raw_svector_ostream os(buffer);
913   os << "c" << value();
914   setNameFn(getResult(), os.str());
915 }
916 
917 //===----------------------------------------------------------------------===//
918 // ConstWitnessOp
919 //===----------------------------------------------------------------------===//
920 
// A constant witness folds unconditionally to its "passing" attribute.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
922 
923 //===----------------------------------------------------------------------===//
924 // CstrRequireOp
925 //===----------------------------------------------------------------------===//
926 
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  // Fold to the boolean operand when it is constant (null otherwise). A true
  // constant yields a passing witness; a false one keeps the failure intact.
  return operands[0];
}
930 
931 //===----------------------------------------------------------------------===//
932 // DivOp
933 //===----------------------------------------------------------------------===//
934 
935 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
936   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
937   if (!lhs)
938     return nullptr;
939   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
940   if (!rhs)
941     return nullptr;
942 
943   // Division in APInt does not follow floor(lhs, rhs) when the result is
944   // negative. Rather, APInt rounds toward zero.
945   APInt quotient, remainder;
946   APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
947   if (quotient.isNegative() && !remainder.isNullValue()) {
948     quotient -= 1;
949   }
950 
951   Type indexTy = IndexType::get(getContext());
952   return IntegerAttr::get(indexTy, quotient);
953 }
954 
955 //===----------------------------------------------------------------------===//
956 // ShapeEqOp
957 //===----------------------------------------------------------------------===//
958 
959 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
960   bool allSame = true;
961   if (!operands.empty() && !operands[0])
962     return {};
963   for (Attribute operand : operands.drop_front(1)) {
964     if (!operand)
965       return {};
966     allSame = allSame && operand == operands[0];
967   }
968   return BoolAttr::get(getContext(), allSame);
969 }
970 
971 //===----------------------------------------------------------------------===//
972 // IndexToSizeOp
973 //===----------------------------------------------------------------------===//
974 
975 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
976   // Constant values of both types, `shape.size` and `index`, are represented as
977   // `IntegerAttr`s which makes constant folding simple.
978   if (Attribute arg = operands[0])
979     return arg;
980   return {};
981 }
982 
void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Cancel out size_to_index(index_to_size(x)) round-trips.
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}
987 
988 //===----------------------------------------------------------------------===//
989 // FromExtentsOp
990 //===----------------------------------------------------------------------===//
991 
992 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
993   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
994     return nullptr;
995   SmallVector<int64_t, 6> extents;
996   for (auto attr : operands)
997     extents.push_back(attr.cast<IntegerAttr>().getInt());
998   Builder builder(getContext());
999   return builder.getIndexTensorAttr(extents);
1000 }
1001 
1002 //===----------------------------------------------------------------------===//
1003 // FunctionLibraryOp
1004 //===----------------------------------------------------------------------===//
1005 
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  // The library only carries its symbol name; the region and mapping are
  // populated separately.
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}
1011 
FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  // Look up the shape function registered for `op`'s name in the "mapping"
  // dictionary; returns null when there is no entry or it is not a symbol ref.
  auto attr = mapping()
                  .get(op->getName().getIdentifier())
                  .dyn_cast_or_null<FlatSymbolRefAttr>();
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}
1020 
// Parses the custom assembly format of shape.function_library:
//   shape.function_library @name [attributes {...}] { region } mapping {...}
ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
                                   OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // Optional attribute dictionary introduced by the "attributes" keyword.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // The region holding the shape functions.
  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  // The op-name-to-shape-function dictionary, stored as "mapping".
  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}
1046 
// Prints shape.function_library in its custom assembly format; inverse of
// parseFunctionLibraryOp.
void print(OpAsmPrinter &p, FunctionLibraryOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.getName());
  // The symbol name and mapping are printed explicitly, so elide them here.
  p.printOptionalAttrDictWithKeyword(
      op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(op.mappingAttr());
}
1057 
1058 //===----------------------------------------------------------------------===//
1059 // GetExtentOp
1060 //===----------------------------------------------------------------------===//
1061 
// Returns the dimension operand as a constant if it is produced by either a
// shape.const_size or a std.constant index op; None otherwise.
Optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.value().getLimitedValue();
  if (auto constantOp = dim().getDefiningOp<ConstantOp>())
    return constantOp.value().cast<IntegerAttr>().getInt();
  return llvm::None;
}
1069 
OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
  // Fold only when both the shape and the dimension index are constant.
  auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!elements)
    return nullptr;
  Optional<int64_t> dim = getConstantDim();
  if (!dim.hasValue())
    return nullptr;
  // Do not fold out-of-bounds accesses; leave them to be diagnosed elsewhere.
  if (dim.getValue() >= elements.getNumElements())
    return nullptr;
  return elements.getValue({(uint64_t)dim.getValue()});
}
1081 
1082 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1083                         int64_t dim) {
1084   auto loc = result.location;
1085   auto dimAttr = builder.getIndexAttr(dim);
1086   if (shape.getType().isa<ShapeType>()) {
1087     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1088     build(builder, result, builder.getType<SizeType>(), shape, dim);
1089   } else {
1090     Value dim =
1091         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
1092     build(builder, result, builder.getIndexType(), shape, dim);
1093   }
1094 }
1095 
1096 //===----------------------------------------------------------------------===//
1097 // IsBroadcastableOp
1098 //===----------------------------------------------------------------------===//
1099 
void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  // Duplicate shape operands never change the broadcastability verdict.
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
1104 
1105 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1106   // Can always broadcast fewer than two shapes.
1107   if (operands.size() < 2) {
1108     return BoolAttr::get(getContext(), true);
1109   }
1110 
1111   return nullptr;
1112 }
1113 
1114 //===----------------------------------------------------------------------===//
1115 // RankOp
1116 //===----------------------------------------------------------------------===//
1117 
1118 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1119   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1120   if (!shape)
1121     return {};
1122   int64_t rank = shape.getNumElements();
1123   Builder builder(getContext());
1124   return builder.getIndexAttr(rank);
1125 }
1126 
1127 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1128 /// Constant folding fails in cases where only the rank is constant, not the
1129 /// shape itself.
1130 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1131 ///
1132 /// Example:
1133 ///
1134 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1135 /// %rank = shape.rank %shape
1136 ///
1137 /// becomes
1138 ///
1139 /// %rank = shape.const_size 3
1140 
namespace {
// Rewrites `shape.rank(shape.shape_of(%t))` to a constant when %t has a
// ranked tensor type, even if the shape itself is not static.
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    // Materialize the rank with a constant matching the result type: `index`
    // results get std.constant, `!shape.size` results get shape.const_size.
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace
1167 
void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Fold rank(shape_of(ranked tensor)) to a constant.
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}
1172 
1173 //===----------------------------------------------------------------------===//
1174 // NumElementsOp
1175 //===----------------------------------------------------------------------===//
1176 
OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {

  // Fold only when argument constant.
  Attribute shape = operands[0];
  if (!shape)
    return {};

  // The number of elements is the product of all extents.
  APInt product(64, 1);
  for (auto value : shape.cast<DenseIntElementsAttr>())
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}
1190 
1191 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
1192                           Value shape) {
1193   if (shape.getType().isa<ShapedType>()) {
1194     auto type = builder.getIndexType();
1195     return build(builder, result, type, shape);
1196   }
1197   auto type = SizeType::get(builder.getContext());
1198   return build(builder, result, type, shape);
1199 }
1200 
1201 //===----------------------------------------------------------------------===//
1202 // MaxOp
1203 //===----------------------------------------------------------------------===//
1204 
1205 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1206   // If operands are equal, just propagate one.
1207   if (lhs() == rhs())
1208     return lhs();
1209   return nullptr;
1210 }
1211 
1212 //===----------------------------------------------------------------------===//
1213 // MinOp
1214 //===----------------------------------------------------------------------===//
1215 
1216 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1217   // If operands are equal, just propagate one.
1218   if (lhs() == rhs())
1219     return lhs();
1220   return nullptr;
1221 }
1222 
1223 //===----------------------------------------------------------------------===//
1224 // MulOp
1225 //===----------------------------------------------------------------------===//
1226 
1227 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1228   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1229   if (!lhs)
1230     return nullptr;
1231   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1232   if (!rhs)
1233     return nullptr;
1234   APInt folded = lhs.getValue() * rhs.getValue();
1235   Type indexTy = IndexType::get(getContext());
1236   return IntegerAttr::get(indexTy, folded);
1237 }
1238 
1239 //===----------------------------------------------------------------------===//
1240 // ShapeOfOp
1241 //===----------------------------------------------------------------------===//
1242 
1243 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1244   auto type = getOperand().getType().dyn_cast<ShapedType>();
1245   if (!type || !type.hasStaticShape())
1246     return nullptr;
1247   Builder builder(getContext());
1248   return builder.getIndexTensorAttr(type.getShape());
1249 }
1250 
1251 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
1252   if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
1253     int64_t rank =
1254         shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1255     Type indexTy = builder.getIndexType();
1256     Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1257     return ShapeOfOp::build(builder, result, extentTensorTy, arg);
1258   }
1259   Type shapeTy = builder.getType<ShapeType>();
1260   return ShapeOfOp::build(builder, result, shapeTy, arg);
1261 }
1262 
namespace {
// Rewrites a shape.shape_of on a shaped value that yields a `!shape.shape`
// into one that yields an extent tensor (the builder picks the tensor type).
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.arg().getType().isa<ShapedType>())
      return failure();
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    // Only rank-1 target types can be extent tensors.
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg());
    return success();
  }
};
} // namespace
1311 
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  // Prefer extent-tensor results and absorb tensor.cast of shape_of results.
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
}
1316 
1317 //===----------------------------------------------------------------------===//
1318 // SizeToIndexOp
1319 //===----------------------------------------------------------------------===//
1320 
1321 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1322   // Constant values of both types, `shape.size` and `index`, are represented as
1323   // `IntegerAttr`s which makes constant folding simple.
1324   if (Attribute arg = operands[0])
1325     return arg;
1326   return impl::foldCastOp(*this);
1327 }
1328 
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Cancel out index_to_size(size_to_index(x)) round-trips.
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}
1333 
1334 //===----------------------------------------------------------------------===//
1335 // YieldOp
1336 //===----------------------------------------------------------------------===//
1337 
static LogicalResult verify(shape::YieldOp op) {
  auto *parentOp = op->getParentOp();
  auto results = parentOp->getResults();
  auto operands = op.getOperands();

  // The yielded operands must match the parent op's results in number...
  if (parentOp->getNumResults() != op.getNumOperands())
    return op.emitOpError() << "number of operands does not match number of "
                               "results of its parent";
  // ...and pairwise in type.
  for (auto e : llvm::zip(results, operands))
    if (std::get<0>(e).getType() != std::get<1>(e).getType())
      return op.emitOpError()
             << "types mismatch between yield op and its parent";

  return success();
}
1353 
1354 //===----------------------------------------------------------------------===//
1355 // SplitAtOp
1356 //===----------------------------------------------------------------------===//
1357 
// Folds shape.split_at when both the shape and the split point are constant,
// producing the head and tail sub-shapes. A negative split point counts from
// the back.
LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<OpFoldResult> &results) {
  if (!operands[0] || !operands[1])
    return failure();
  auto shapeVec = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto shape = llvm::makeArrayRef(shapeVec);
  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
  // Verify that the split point is in the correct range.
  // TODO: Constant fold to an "error".
  int64_t rank = shape.size();
  if (!(-rank <= splitPoint && splitPoint <= rank))
    return failure();
  // Normalize a negative split point to an offset from the front.
  if (splitPoint < 0)
    splitPoint += shape.size();
  Builder builder(operands[0].getContext());
  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
  return success();
}
1378 
1379 //===----------------------------------------------------------------------===//
1380 // ToExtentTensorOp
1381 //===----------------------------------------------------------------------===//
1382 
1383 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1384   if (!operands[0])
1385     return impl::foldCastOp(*this);
1386   Builder builder(getContext());
1387   auto shape = llvm::to_vector<6>(
1388       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1389   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1390                                     builder.getIndexType());
1391   return DenseIntElementsAttr::get(type, shape);
1392 }
1393 
1394 //===----------------------------------------------------------------------===//
1395 // ReduceOp
1396 //===----------------------------------------------------------------------===//
1397 
void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  result.addOperands(shape);
  result.addOperands(initVals);

  // Create the body block with the implicit arguments: the iteration index,
  // the current extent, and one accumulator per initial value.
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  bodyBlock.addArgument(builder.getIndexType());

  // The extent argument is `index` for extent-tensor operands and
  // `!shape.size` for `!shape.shape` operands.
  Type elementType;
  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
    elementType = tensorType.getElementType();
  else
    elementType = SizeType::get(builder.getContext());
  bodyBlock.addArgument(elementType);

  // Accumulators: one block argument and one op result per initial value.
  for (Type initValType : initVals.getTypes()) {
    bodyBlock.addArgument(initValType);
    result.addTypes(initValType);
  }
}
1420 
static LogicalResult verify(ReduceOp op) {
  // Verify block arg types.
  Block &block = op.region().front();

  // The block takes index, extent, and aggregated values as arguments.
  auto blockArgsCount = op.initVals().size() + 2;
  if (block.getNumArguments() != blockArgsCount)
    return op.emitOpError() << "ReduceOp body is expected to have "
                            << blockArgsCount << " arguments";

  // The first block argument is the index and must always be of type `index`.
  if (!block.getArgument(0).getType().isa<IndexType>())
    return op.emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  // The second block argument is the extent and must be of type `size` or
  // `index`, depending on whether the reduce operation is applied to a shape or
  // to an extent tensor.
  Type extentTy = block.getArgument(1).getType();
  if (op.shape().getType().isa<ShapeType>()) {
    if (!extentTy.isa<SizeType>())
      return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
                            "SizeType if the ReduceOp operates on a ShapeType");
  } else {
    if (!extentTy.isa<IndexType>())
      return op.emitOpError(
          "argument 1 of ReduceOp body is expected to be of IndexType if the "
          "ReduceOp operates on an extent tensor");
  }

  // The remaining block arguments are accumulators and must match the types
  // of the corresponding initial values.
  for (auto type : llvm::enumerate(op.initVals()))
    if (block.getArgument(type.index() + 2).getType() != type.value().getType())
      return op.emitOpError()
             << "type mismatch between argument " << type.index() + 2
             << " of ReduceOp body and initial value " << type.index();
  return success();
}
1458 
// Parses the custom assembly format of shape.reduce:
//   shape.reduce(%shape, %init...) : type [-> types] { region } [attrs]
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands: the first is the shape; the rest are initial values
  // whose types match the result types.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1488 
// Prints shape.reduce in its custom assembly format; inverse of parseReduceOp.
static void print(OpAsmPrinter &p, ReduceOp op) {
  p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
    << ") : " << op.shape().getType();
  p.printOptionalArrowTypeList(op.getResultTypes());
  p.printRegion(op.region());
  p.printOptionalAttrDict(op->getAttrs());
}
1496 
1497 #define GET_OP_CLASSES
1498 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1499