1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Dialect/Traits.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 #include "llvm/ADT/SmallString.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 using namespace mlir;
24 using namespace mlir::shape;
25 
26 namespace {
27 #include "ShapeCanonicalization.inc"
28 }
29 
30 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
31   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
32 }
33 
34 static bool isErrorPropagationPossible(TypeRange operandTypes) {
35   return llvm::any_of(operandTypes, [](Type ty) {
36     return ty.isa<SizeType, ShapeType, ValueShapeType>();
37   });
38 }
39 
40 static LogicalResult verifySizeOrIndexOp(Operation *op) {
41   assert(op != nullptr && op->getNumResults() == 1);
42   Type resultTy = op->getResultTypes().front();
43   if (isErrorPropagationPossible(op->getOperandTypes())) {
44     if (!resultTy.isa<SizeType>())
45       return op->emitOpError()
46              << "if at least one of the operands can hold error values then "
47                 "the result must be of type `size` to propagate them";
48   }
49   return success();
50 }
51 
52 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
53   assert(op != nullptr && op->getNumResults() == 1);
54   Type resultTy = op->getResultTypes().front();
55   if (isErrorPropagationPossible(op->getOperandTypes())) {
56     if (!resultTy.isa<ShapeType>())
57       return op->emitOpError()
58              << "if at least one of the operands can hold error values then "
59                 "the result must be of type `shape` to propagate them";
60   }
61   return success();
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // InlinerInterface
66 //===----------------------------------------------------------------------===//
67 
68 namespace {
69 /// This class defines the interface for inlining shape dialect ops.
70 struct ShapeInlinerInterface : public DialectInlinerInterface {
71   using DialectInlinerInterface::DialectInlinerInterface;
72 
73   // Returns true if the given region 'src' can be inlined into the region
74   // 'dest' that is attached to an operation registered to the current dialect.
75   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
76                        BlockAndValueMapping &) const final {
77     return true;
78   }
79 
80   // Returns true if the given operation 'op', that is registered to this
81   // dialect, can be inlined into the region 'dest' that is attached to an
82   // operation registered to the current dialect.
83   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
84                        BlockAndValueMapping &) const final {
85     return true;
86   }
87 };
88 } // namespace
89 
/// Registers the shape dialect's operations, types (`shape`, `size`,
/// `value_shape`, `witness`) and its inliner interface.
void ShapeDialect::initialize() {
  // The operation list is spliced in from the ODS-generated ShapeOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
102 
103 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
104                                              Attribute value, Type type,
105                                              Location loc) {
106   if (type.isa<ShapeType>() ||
107       type == getExtentTensorType(builder.getContext()))
108     return builder.create<ConstShapeOp>(loc, type,
109                                         value.cast<DenseIntElementsAttr>());
110   if (type.isa<SizeType>())
111     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
112   if (type.isa<WitnessType>())
113     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
114   if (ConstantOp::isBuildableWith(value, type))
115     return builder.create<ConstantOp>(loc, type, value);
116   return nullptr;
117 }
118 
119 /// Parse a type registered to this dialect.
120 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
121   StringRef keyword;
122   if (parser.parseKeyword(&keyword))
123     return Type();
124 
125   if (keyword == "shape")
126     return ShapeType::get(getContext());
127   if (keyword == "size")
128     return SizeType::get(getContext());
129   if (keyword == "value_shape")
130     return ValueShapeType::get(getContext());
131   if (keyword == "witness")
132     return WitnessType::get(getContext());
133 
134   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
135   return Type();
136 }
137 
138 /// Print a type registered to this dialect.
139 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
140   TypeSwitch<Type>(type)
141       .Case<ShapeType>([&](Type) { os << "shape"; })
142       .Case<SizeType>([&](Type) { os << "size"; })
143       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
144       .Case<WitnessType>([&](Type) { os << "witness"; })
145       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
146 }
147 
148 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
149                                                      NamedAttribute attribute) {
150   // Verify shape.lib attribute.
151   if (attribute.first == "shape.lib") {
152     if (!op->hasTrait<OpTrait::SymbolTable>())
153       return op->emitError(
154           "shape.lib attribute may only be on op implementing SymbolTable");
155 
156     if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) {
157       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
158       if (!symbol)
159         return op->emitError("shape function library ")
160                << symbolRef << " not found";
161       return isa<shape::FunctionLibraryOp>(symbol)
162                  ? success()
163                  : op->emitError()
164                        << symbolRef << " required to be shape function library";
165     }
166 
167     if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) {
168       // Verify all entries are function libraries and mappings in libraries
169       // refer to unique ops.
170       DenseSet<Identifier> key;
171       for (auto it : arr) {
172         if (!it.isa<SymbolRefAttr>())
173           return op->emitError(
174               "only SymbolRefAttr allowed in shape.lib attribute array");
175 
176         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
177             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
178         if (!shapeFnLib)
179           return op->emitError()
180                  << it << " does not refer to FunctionLibraryOp";
181         for (auto mapping : shapeFnLib.mapping()) {
182           if (!key.insert(mapping.first).second) {
183             return op->emitError("only one op to shape mapping allowed, found "
184                                  "multiple for `")
185                    << mapping.first << "`";
186           }
187         }
188       }
189       return success();
190     }
191 
192     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
193                          "allowed as shape.lib attribute");
194   }
195   return success();
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // AnyOp
200 //===----------------------------------------------------------------------===//
201 
202 // TODO: Canonicalization should be implemented for shapes that can be
203 // determined through mixtures of the known dimensions of the inputs.
204 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
205   // Only the last operand is checked because AnyOp is commutative.
206   if (operands.back())
207     return operands.back();
208 
209   return nullptr;
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // AssumingOp
214 //===----------------------------------------------------------------------===//
215 
216 static ParseResult parseAssumingOp(OpAsmParser &parser,
217                                    OperationState &result) {
218   result.regions.reserve(1);
219   Region *doRegion = result.addRegion();
220 
221   auto &builder = parser.getBuilder();
222   OpAsmParser::OperandType cond;
223   if (parser.parseOperand(cond) ||
224       parser.resolveOperand(cond, builder.getType<WitnessType>(),
225                             result.operands))
226     return failure();
227 
228   // Parse optional results type list.
229   if (parser.parseOptionalArrowTypeList(result.types))
230     return failure();
231 
232   // Parse the region and add a terminator if elided.
233   if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
234     return failure();
235   AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
236 
237   // Parse the optional attribute list.
238   if (parser.parseOptionalAttrDict(result.attributes))
239     return failure();
240   return success();
241 }
242 
243 static void print(OpAsmPrinter &p, AssumingOp op) {
244   bool yieldsResults = !op.results().empty();
245 
246   p << AssumingOp::getOperationName() << " " << op.witness();
247   if (yieldsResults) {
248     p << " -> (" << op.getResultTypes() << ")";
249   }
250   p.printRegion(op.doRegion(),
251                 /*printEntryBlockArgs=*/false,
252                 /*printBlockTerminators=*/yieldsResults);
253   p.printOptionalAttrDict(op->getAttrs());
254 }
255 
256 namespace {
257 // Removes AssumingOp with a passing witness and inlines the region.
258 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
259   using OpRewritePattern<AssumingOp>::OpRewritePattern;
260 
261   LogicalResult matchAndRewrite(AssumingOp op,
262                                 PatternRewriter &rewriter) const override {
263     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
264     if (!witness || !witness.passingAttr())
265       return failure();
266 
267     AssumingOp::inlineRegionIntoParent(op, rewriter);
268     return success();
269   }
270 };
271 } // namespace
272 
273 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
274                                              MLIRContext *context) {
275   // If taking a passing witness, inline region.
276   patterns.add<AssumingWithTrue>(context);
277 }
278 
279 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
280 void AssumingOp::getSuccessorRegions(
281     Optional<unsigned> index, ArrayRef<Attribute> operands,
282     SmallVectorImpl<RegionSuccessor> &regions) {
283   // AssumingOp has unconditional control flow into the region and back to the
284   // parent, so return the correct RegionSuccessor purely based on the index
285   // being None or 0.
286   if (index.hasValue()) {
287     regions.push_back(RegionSuccessor(getResults()));
288     return;
289   }
290 
291   regions.push_back(RegionSuccessor(&doRegion()));
292 }
293 
/// Inlines the single body block of `op` into the enclosing block and erases
/// `op`. Uses of `op`'s results are replaced with the operands of the body's
/// terminator (the yield), which is then erased.
///
/// NOTE(review): the block is split at the rewriter's current insertion
/// point, so this presumably expects the insertion point to be at `op` (as is
/// typical inside a pattern's matchAndRewrite) — confirm at call sites.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Everything from the insertion point onward moves to a new block, so the
  // body can be spliced in between the two halves.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  // The yield is the last op of the body; its operands replace op's results.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
313 
314 void AssumingOp::build(
315     OpBuilder &builder, OperationState &result, Value witness,
316     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
317 
318   result.addOperands(witness);
319   Region *bodyRegion = result.addRegion();
320   bodyRegion->push_back(new Block);
321   Block &bodyBlock = bodyRegion->front();
322 
323   // Build body.
324   OpBuilder::InsertionGuard guard(builder);
325   builder.setInsertionPointToStart(&bodyBlock);
326   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
327   builder.create<AssumingYieldOp>(result.location, yieldValues);
328 
329   SmallVector<Type, 2> assumingTypes;
330   for (Value v : yieldValues)
331     assumingTypes.push_back(v.getType());
332   result.addTypes(assumingTypes);
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // AssumingAllOp
337 //===----------------------------------------------------------------------===//
338 
339 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
340                                                 MLIRContext *context) {
341   patterns.add<AssumingAllOneOp>(context);
342 }
343 
344 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
345   // Iterate in reverse to first handle all constant operands. They are
346   // guaranteed to be the tail of the inputs because this is commutative.
347   for (int idx = operands.size() - 1; idx >= 0; idx--) {
348     Attribute a = operands[idx];
349     // Cannot fold if any inputs are not constant;
350     if (!a)
351       return nullptr;
352 
353     // We do not need to keep statically known values after handling them in
354     // this method.
355     getOperation()->eraseOperand(idx);
356 
357     // Always false if any input is statically known false
358     if (!a.cast<BoolAttr>().getValue())
359       return a;
360   }
361   // If this is reached, all inputs were statically known passing.
362   return BoolAttr::get(getContext(), true);
363 }
364 
365 static LogicalResult verify(AssumingAllOp op) {
366   // Ensure that AssumingAllOp contains at least one operand
367   if (op.getNumOperands() == 0)
368     return op.emitOpError("no operands specified");
369 
370   return success();
371 }
372 
373 //===----------------------------------------------------------------------===//
374 // BroadcastOp
375 //===----------------------------------------------------------------------===//
376 
377 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
378   if (operands.size() == 1)
379     return shapes().front();
380 
381   // TODO: Support folding with more than 2 input shapes
382   if (shapes().size() > 2)
383     return nullptr;
384 
385   if (!operands[1])
386     return nullptr;
387 
388   auto rhsShape = llvm::to_vector<6>(
389       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
390   if (rhsShape.empty())
391     return shapes()[0];
392 
393   if (!operands[0])
394     return nullptr;
395 
396   auto lhsShape = llvm::to_vector<6>(
397       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
398   if (lhsShape.empty())
399     return shapes()[1];
400 
401   SmallVector<int64_t, 6> resultShape;
402   // If the shapes are not compatible, we can't fold it.
403   // TODO: Fold to an "error".
404   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
405     return nullptr;
406   Builder builder(getContext());
407   return builder.getIndexTensorAttr(resultShape);
408 }
409 
410 static LogicalResult verify(BroadcastOp op) {
411   return verifyShapeOrExtentTensorOp(op);
412 }
413 
414 namespace {
415 template <typename OpTy>
416 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
417   using OpRewritePattern<OpTy>::OpRewritePattern;
418 
419   LogicalResult matchAndRewrite(OpTy op,
420                                 PatternRewriter &rewriter) const override {
421     // Find unique operands.
422     SmallVector<Value, 2> unique;
423     for (Value v : op.getOperands()) {
424       if (!llvm::is_contained(unique, v))
425         unique.push_back(v);
426     }
427 
428     // Reduce op to equivalent with unique operands.
429     if (unique.size() < op.getNumOperands()) {
430       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
431                                         op->getAttrs());
432       return success();
433     }
434 
435     return failure();
436   }
437 };
438 
439 struct BroadcastForwardSingleOperandPattern
440     : public OpRewritePattern<BroadcastOp> {
441   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
442 
443   LogicalResult matchAndRewrite(BroadcastOp op,
444                                 PatternRewriter &rewriter) const override {
445     if (op.getNumOperands() == 1) {
446       rewriter.replaceOp(op, op.shapes().front());
447       return success();
448     }
449     return failure();
450   }
451 };
452 } // namespace
453 
454 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
455                                               MLIRContext *context) {
456   patterns.add<BroadcastForwardSingleOperandPattern,
457                RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
458 }
459 
460 //===----------------------------------------------------------------------===//
461 // ConcatOp
462 //===----------------------------------------------------------------------===//
463 
464 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
465   if (!operands[0] || !operands[1])
466     return nullptr;
467   auto lhsShape = llvm::to_vector<6>(
468       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
469   auto rhsShape = llvm::to_vector<6>(
470       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
471   SmallVector<int64_t, 6> resultShape;
472   resultShape.append(lhsShape.begin(), lhsShape.end());
473   resultShape.append(rhsShape.begin(), rhsShape.end());
474   Builder builder(getContext());
475   return builder.getIndexTensorAttr(resultShape);
476 }
477 
478 //===----------------------------------------------------------------------===//
479 // ConstShapeOp
480 //===----------------------------------------------------------------------===//
481 
482 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
483   p << "shape.const_shape ";
484   p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
485   p << "[";
486   interleaveComma(op.shape().getValues<int64_t>(), p,
487                   [&](int64_t i) { p << i; });
488   p << "] : ";
489   p.printType(op.getType());
490 }
491 
492 static ParseResult parseConstShapeOp(OpAsmParser &parser,
493                                      OperationState &result) {
494   if (parser.parseOptionalAttrDict(result.attributes))
495     return failure();
496   // We piggy-back on ArrayAttr parsing, though we don't internally store the
497   // shape as an ArrayAttr.
498   // TODO: Implement custom parser and maybe make syntax a bit more concise.
499   Attribute extentsRaw;
500   NamedAttrList dummy;
501   if (parser.parseAttribute(extentsRaw, "dummy", dummy))
502     return failure();
503   auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
504   if (!extentsArray)
505     return failure();
506   SmallVector<int64_t, 6> ints;
507   for (Attribute extent : extentsArray) {
508     IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
509     if (!attr)
510       return failure();
511     ints.push_back(attr.getInt());
512   }
513   Builder &builder = parser.getBuilder();
514   result.addAttribute("shape", builder.getIndexTensorAttr(ints));
515   Type resultTy;
516   if (parser.parseColonType(resultTy))
517     return failure();
518   result.types.push_back(resultTy);
519   return success();
520 }
521 
522 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
523 
524 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
525                                                MLIRContext *context) {
526   patterns.add<TensorCastConstShape>(context);
527 }
528 
529 //===----------------------------------------------------------------------===//
530 // CstrBroadcastableOp
531 //===----------------------------------------------------------------------===//
532 
533 namespace {
534 // Given an input shape Value, try to obtain the shape's values.
535 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
536   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
537     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
538     if (!type.hasRank())
539       return failure();
540     shapeValues = llvm::to_vector<6>(type.getShape());
541     return success();
542   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
543     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
544     return success();
545   } else {
546     return failure();
547   }
548 }
549 } // namespace
550 
551 void CstrBroadcastableOp::getCanonicalizationPatterns(
552     RewritePatternSet &patterns, MLIRContext *context) {
553   // Canonicalization patterns have overlap with the considerations during
554   // folding in case additional shape information is inferred at some point that
555   // does not result in folding.
556   patterns.add<CstrBroadcastableEqOps,
557                RemoveDuplicateOperandsPattern<CstrBroadcastableOp>>(context);
558 }
559 
560 // Return true if there is exactly one attribute not representing a scalar
561 // broadcast.
562 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
563   bool nonScalarSeen = false;
564   for (Attribute a : attributes) {
565     if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
566       if (nonScalarSeen)
567         return false;
568       nonScalarSeen = true;
569     }
570   }
571   return true;
572 }
573 
574 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
575   // No broadcasting is needed if all operands but one are scalar.
576   if (hasAtMostSingleNonScalar(operands))
577     return BoolAttr::get(getContext(), true);
578 
579   if ([&] {
580         SmallVector<SmallVector<int64_t, 6>, 6> extents;
581         for (const auto &operand : operands) {
582           if (!operand)
583             return false;
584           extents.push_back(llvm::to_vector<6>(
585               operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
586         }
587         return OpTrait::util::staticallyKnownBroadcastable(extents);
588       }())
589     return BoolAttr::get(getContext(), true);
590 
591   // Lastly, see if folding can be completed based on what constraints are known
592   // on the input shapes.
593   if ([&] {
594         SmallVector<SmallVector<int64_t, 6>, 6> extents;
595         for (auto shapeValue : shapes()) {
596           extents.emplace_back();
597           if (failed(getShapeVec(shapeValue, extents.back())))
598             return false;
599         }
600         return OpTrait::util::staticallyKnownBroadcastable(extents);
601       }())
602     return BoolAttr::get(getContext(), true);
603 
604   // Because a failing witness result here represents an eventual assertion
605   // failure, we do not replace it with a constant witness.
606   return nullptr;
607 }
608 
609 static LogicalResult verify(CstrBroadcastableOp op) {
610   // Ensure that AssumingAllOp contains at least one operand
611   if (op.getNumOperands() < 2)
612     return op.emitOpError("required at least 2 input shapes");
613   return success();
614 }
615 
616 //===----------------------------------------------------------------------===//
617 // CstrEqOp
618 //===----------------------------------------------------------------------===//
619 
620 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
621                                            MLIRContext *context) {
622   // If inputs are equal, return passing witness
623   patterns.add<CstrEqEqOps>(context);
624 }
625 
626 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
627   if (llvm::all_of(operands,
628                    [&](Attribute a) { return a && a == operands[0]; }))
629     return BoolAttr::get(getContext(), true);
630 
631   // Because a failing witness result here represents an eventual assertion
632   // failure, we do not try to replace it with a constant witness. Similarly, we
633   // cannot if there are any non-const inputs.
634   return nullptr;
635 }
636 
637 //===----------------------------------------------------------------------===//
638 // ConstSizeOp
639 //===----------------------------------------------------------------------===//
640 
641 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
642                         int64_t value) {
643   build(builder, result, builder.getIndexAttr(value));
644 }
645 
646 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
647 
648 void ConstSizeOp::getAsmResultNames(
649     llvm::function_ref<void(Value, StringRef)> setNameFn) {
650   SmallString<4> buffer;
651   llvm::raw_svector_ostream os(buffer);
652   os << "c" << value();
653   setNameFn(getResult(), os.str());
654 }
655 
656 //===----------------------------------------------------------------------===//
657 // ConstWitnessOp
658 //===----------------------------------------------------------------------===//
659 
660 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
661 
662 //===----------------------------------------------------------------------===//
663 // CstrRequireOp
664 //===----------------------------------------------------------------------===//
665 
666 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
667   return operands[0];
668 }
669 
670 //===----------------------------------------------------------------------===//
671 // DivOp
672 //===----------------------------------------------------------------------===//
673 
674 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
675   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
676   if (!lhs)
677     return nullptr;
678   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
679   if (!rhs)
680     return nullptr;
681 
682   // Division in APInt does not follow floor(lhs, rhs) when the result is
683   // negative. Rather, APInt rounds toward zero.
684   APInt quotient, remainder;
685   APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
686   if (quotient.isNegative() && !remainder.isNullValue()) {
687     quotient -= 1;
688   }
689 
690   Type indexTy = IndexType::get(getContext());
691   return IntegerAttr::get(indexTy, quotient);
692 }
693 
694 //===----------------------------------------------------------------------===//
695 // ShapeEqOp
696 //===----------------------------------------------------------------------===//
697 
698 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
699   bool allSame = true;
700   if (!operands.empty() && !operands[0])
701     return {};
702   for (Attribute operand : operands.drop_front(1)) {
703     if (!operand)
704       return {};
705     allSame = allSame && operand == operands[0];
706   }
707   return BoolAttr::get(getContext(), allSame);
708 }
709 
710 //===----------------------------------------------------------------------===//
711 // IndexToSizeOp
712 //===----------------------------------------------------------------------===//
713 
714 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
715   // Constant values of both types, `shape.size` and `index`, are represented as
716   // `IntegerAttr`s which makes constant folding simple.
717   if (Attribute arg = operands[0])
718     return arg;
719   return {};
720 }
721 
722 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
723                                                 MLIRContext *context) {
724   patterns.add<SizeToIndexToSizeCanonicalization>(context);
725 }
726 
727 //===----------------------------------------------------------------------===//
728 // FromExtentsOp
729 //===----------------------------------------------------------------------===//
730 
731 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
732   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
733     return nullptr;
734   SmallVector<int64_t, 6> extents;
735   for (auto attr : operands)
736     extents.push_back(attr.cast<IntegerAttr>().getInt());
737   Builder builder(getContext());
738   return builder.getIndexTensorAttr(extents);
739 }
740 
741 //===----------------------------------------------------------------------===//
742 // FunctionLibraryOp
743 //===----------------------------------------------------------------------===//
744 
745 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
746                               StringRef name) {
747   ensureTerminator(*result.addRegion(), builder, result.location);
748   result.attributes.push_back(builder.getNamedAttr(
749       ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
750 }
751 
752 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
753   auto attr = mapping()
754                   .get(op->getName().getIdentifier())
755                   .dyn_cast_or_null<FlatSymbolRefAttr>();
756   if (!attr)
757     return nullptr;
758   return lookupSymbol<FuncOp>(attr);
759 }
760 
761 ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
762                                    OperationState &result) {
763   // Parse the op name.
764   StringAttr nameAttr;
765   if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
766                              result.attributes))
767     return failure();
768 
769   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
770     return failure();
771 
772   auto *bodyRegion = result.addRegion();
773   if (parser.parseRegion(*bodyRegion))
774     return failure();
775 
776   FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
777                                       result.location);
778   if (parser.parseKeyword("mapping"))
779     return failure();
780 
781   DictionaryAttr mappingAttr;
782   if (parser.parseAttribute(mappingAttr,
783                             parser.getBuilder().getType<NoneType>(), "mapping",
784                             result.attributes))
785     return failure();
786   return success();
787 }
788 
789 void print(OpAsmPrinter &p, FunctionLibraryOp op) {
790   p << op.getOperationName() << ' ';
791   p.printSymbolName(op.getName());
792   p.printOptionalAttrDictWithKeyword(
793       op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
794   p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
795                 /*printBlockTerminators=*/false);
796   p << " mapping ";
797   p.printAttributeWithoutType(op.mappingAttr());
798 }
799 
800 //===----------------------------------------------------------------------===//
801 // GetExtentOp
802 //===----------------------------------------------------------------------===//
803 
804 Optional<int64_t> GetExtentOp::getConstantDim() {
805   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
806     return constSizeOp.value().getLimitedValue();
807   if (auto constantOp = dim().getDefiningOp<ConstantOp>())
808     return constantOp.value().cast<IntegerAttr>().getInt();
809   return llvm::None;
810 }
811 
812 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
813   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
814   if (!elements)
815     return nullptr;
816   Optional<int64_t> dim = getConstantDim();
817   if (!dim.hasValue())
818     return nullptr;
819   if (dim.getValue() >= elements.getNumElements())
820     return nullptr;
821   return elements.getValue({(uint64_t)dim.getValue()});
822 }
823 
824 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
825                         int64_t dim) {
826   auto loc = result.location;
827   auto dimAttr = builder.getIndexAttr(dim);
828   if (shape.getType().isa<ShapeType>()) {
829     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
830     build(builder, result, builder.getType<SizeType>(), shape, dim);
831   } else {
832     Value dim =
833         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
834     build(builder, result, builder.getIndexType(), shape, dim);
835   }
836 }
837 
838 //===----------------------------------------------------------------------===//
839 // IsBroadcastableOp
840 //===----------------------------------------------------------------------===//
841 
842 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
843                                                     MLIRContext *context) {
844   patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
845 }
846 
847 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
848   // Can always broadcast fewer than two shapes.
849   if (operands.size() < 2) {
850     return BoolAttr::get(getContext(), true);
851   }
852 
853   return nullptr;
854 }
855 
856 //===----------------------------------------------------------------------===//
857 // RankOp
858 //===----------------------------------------------------------------------===//
859 
860 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
861   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
862   if (!shape)
863     return {};
864   int64_t rank = shape.getNumElements();
865   Builder builder(getContext());
866   return builder.getIndexAttr(rank);
867 }
868 
869 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
870 /// Constant folding fails in cases where only the rank is constant, not the
871 /// shape itself.
872 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
873 ///
874 /// Example:
875 ///
876 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
877 /// %rank = shape.rank %shape
878 ///
879 /// becomes
880 ///
881 /// %rank = shape.const_size 3
882 
883 namespace {
884 struct RankShapeOfCanonicalizationPattern
885     : public OpRewritePattern<shape::RankOp> {
886   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
887 
888   LogicalResult matchAndRewrite(shape::RankOp op,
889                                 PatternRewriter &rewriter) const override {
890     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
891     if (!shapeOfOp)
892       return failure();
893     auto rankedTensorType =
894         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
895     if (!rankedTensorType)
896       return failure();
897     int64_t rank = rankedTensorType.getRank();
898     if (op.getType().isa<IndexType>()) {
899       rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
900     } else if (op.getType().isa<shape::SizeType>()) {
901       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
902     } else {
903       return failure();
904     }
905     return success();
906   }
907 };
908 } // namespace
909 
910 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
911                                                 MLIRContext *context) {
912   patterns.add<RankShapeOfCanonicalizationPattern>(context);
913 }
914 
915 //===----------------------------------------------------------------------===//
916 // NumElementsOp
917 //===----------------------------------------------------------------------===//
918 
919 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
920 
921   // Fold only when argument constant.
922   Attribute shape = operands[0];
923   if (!shape)
924     return {};
925 
926   APInt product(64, 1);
927   for (auto value : shape.cast<DenseIntElementsAttr>())
928     product *= value;
929   Builder builder(getContext());
930   return builder.getIndexAttr(product.getLimitedValue());
931 }
932 
933 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
934                           Value shape) {
935   if (shape.getType().isa<ShapedType>()) {
936     auto type = builder.getIndexType();
937     return build(builder, result, type, shape);
938   }
939   auto type = SizeType::get(builder.getContext());
940   return build(builder, result, type, shape);
941 }
942 
943 //===----------------------------------------------------------------------===//
944 // MulOp
945 //===----------------------------------------------------------------------===//
946 
947 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
948   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
949   if (!lhs)
950     return nullptr;
951   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
952   if (!rhs)
953     return nullptr;
954   APInt folded = lhs.getValue() * rhs.getValue();
955   Type indexTy = IndexType::get(getContext());
956   return IntegerAttr::get(indexTy, folded);
957 }
958 
959 //===----------------------------------------------------------------------===//
960 // ShapeOfOp
961 //===----------------------------------------------------------------------===//
962 
963 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
964   auto type = getOperand().getType().dyn_cast<ShapedType>();
965   if (!type || !type.hasStaticShape())
966     return nullptr;
967   Builder builder(getContext());
968   return builder.getIndexTensorAttr(type.getShape());
969 }
970 
971 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
972   Type type = arg.getType().isa<ShapedType>()
973                   ? (Type)getExtentTensorType(builder.getContext())
974                   : (Type)builder.getType<ShapeType>();
975   return ShapeOfOp::build(builder, result, type, arg);
976 }
977 
978 namespace {
979 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
980   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
981 
982   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
983                                 PatternRewriter &rewriter) const override {
984     if (!op.arg().getType().isa<ShapedType>())
985       return failure();
986     if (op.getType().isa<ShapedType>())
987       return failure();
988 
989     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
990     return success();
991   }
992 };
993 } // namespace
994 
995 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
996                                             MLIRContext *context) {
997   patterns.add<ShapeOfWithTensor>(context);
998 }
999 
1000 //===----------------------------------------------------------------------===//
1001 // SizeToIndexOp
1002 //===----------------------------------------------------------------------===//
1003 
1004 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1005   // Constant values of both types, `shape.size` and `index`, are represented as
1006   // `IntegerAttr`s which makes constant folding simple.
1007   if (Attribute arg = operands[0])
1008     return arg;
1009   return impl::foldCastOp(*this);
1010 }
1011 
1012 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1013                                                 MLIRContext *context) {
1014   patterns.add<IndexToSizeToIndexCanonicalization>(context);
1015 }
1016 
1017 //===----------------------------------------------------------------------===//
1018 // YieldOp
1019 //===----------------------------------------------------------------------===//
1020 
1021 static LogicalResult verify(shape::YieldOp op) {
1022   auto *parentOp = op->getParentOp();
1023   auto results = parentOp->getResults();
1024   auto operands = op.getOperands();
1025 
1026   if (parentOp->getNumResults() != op.getNumOperands())
1027     return op.emitOpError() << "number of operands does not match number of "
1028                                "results of its parent";
1029   for (auto e : llvm::zip(results, operands))
1030     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1031       return op.emitOpError()
1032              << "types mismatch between yield op and its parent";
1033 
1034   return success();
1035 }
1036 
1037 //===----------------------------------------------------------------------===//
1038 // SplitAtOp
1039 //===----------------------------------------------------------------------===//
1040 
1041 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1042                               SmallVectorImpl<OpFoldResult> &results) {
1043   if (!operands[0] || !operands[1])
1044     return failure();
1045   auto shapeVec = llvm::to_vector<6>(
1046       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1047   auto shape = llvm::makeArrayRef(shapeVec);
1048   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1049   // Verify that the split point is in the correct range.
1050   // TODO: Constant fold to an "error".
1051   int64_t rank = shape.size();
1052   if (!(-rank <= splitPoint && splitPoint <= rank))
1053     return failure();
1054   if (splitPoint < 0)
1055     splitPoint += shape.size();
1056   Builder builder(operands[0].getContext());
1057   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1058   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1059   return success();
1060 }
1061 
1062 //===----------------------------------------------------------------------===//
1063 // ToExtentTensorOp
1064 //===----------------------------------------------------------------------===//
1065 
1066 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1067   if (!operands[0])
1068     return impl::foldCastOp(*this);
1069   Builder builder(getContext());
1070   auto shape = llvm::to_vector<6>(
1071       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1072   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1073                                     builder.getIndexType());
1074   return DenseIntElementsAttr::get(type, shape);
1075 }
1076 
1077 //===----------------------------------------------------------------------===//
1078 // ReduceOp
1079 //===----------------------------------------------------------------------===//
1080 
1081 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1082                      ValueRange initVals) {
1083   result.addOperands(shape);
1084   result.addOperands(initVals);
1085 
1086   Region *bodyRegion = result.addRegion();
1087   bodyRegion->push_back(new Block);
1088   Block &bodyBlock = bodyRegion->front();
1089   bodyBlock.addArgument(builder.getIndexType());
1090 
1091   Type elementType;
1092   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1093     elementType = tensorType.getElementType();
1094   else
1095     elementType = SizeType::get(builder.getContext());
1096   bodyBlock.addArgument(elementType);
1097 
1098   for (Type initValType : initVals.getTypes()) {
1099     bodyBlock.addArgument(initValType);
1100     result.addTypes(initValType);
1101   }
1102 }
1103 
1104 static LogicalResult verify(ReduceOp op) {
1105   // Verify block arg types.
1106   Block &block = op.region().front();
1107 
1108   // The block takes index, extent, and aggregated values as arguments.
1109   auto blockArgsCount = op.initVals().size() + 2;
1110   if (block.getNumArguments() != blockArgsCount)
1111     return op.emitOpError() << "ReduceOp body is expected to have "
1112                             << blockArgsCount << " arguments";
1113 
1114   // The first block argument is the index and must always be of type `index`.
1115   if (!block.getArgument(0).getType().isa<IndexType>())
1116     return op.emitOpError(
1117         "argument 0 of ReduceOp body is expected to be of IndexType");
1118 
1119   // The second block argument is the extent and must be of type `size` or
1120   // `index`, depending on whether the reduce operation is applied to a shape or
1121   // to an extent tensor.
1122   Type extentTy = block.getArgument(1).getType();
1123   if (op.shape().getType().isa<ShapeType>()) {
1124     if (!extentTy.isa<SizeType>())
1125       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
1126                             "SizeType if the ReduceOp operates on a ShapeType");
1127   } else {
1128     if (!extentTy.isa<IndexType>())
1129       return op.emitOpError(
1130           "argument 1 of ReduceOp body is expected to be of IndexType if the "
1131           "ReduceOp operates on an extent tensor");
1132   }
1133 
1134   for (auto type : llvm::enumerate(op.initVals()))
1135     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1136       return op.emitOpError()
1137              << "type mismatch between argument " << type.index() + 2
1138              << " of ReduceOp body and initial value " << type.index();
1139   return success();
1140 }
1141 
1142 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
1143   // Parse operands.
1144   SmallVector<OpAsmParser::OperandType, 3> operands;
1145   Type shapeOrExtentTensorType;
1146   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1147                               OpAsmParser::Delimiter::Paren) ||
1148       parser.parseColonType(shapeOrExtentTensorType) ||
1149       parser.parseOptionalArrowTypeList(result.types))
1150     return failure();
1151 
1152   // Resolve operands.
1153   auto initVals = llvm::makeArrayRef(operands).drop_front();
1154   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1155                             result.operands) ||
1156       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1157                              result.operands))
1158     return failure();
1159 
1160   // Parse the body.
1161   Region *body = result.addRegion();
1162   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
1163     return failure();
1164 
1165   // Parse attributes.
1166   if (parser.parseOptionalAttrDict(result.attributes))
1167     return failure();
1168 
1169   return success();
1170 }
1171 
1172 static void print(OpAsmPrinter &p, ReduceOp op) {
1173   p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
1174     << ") : " << op.shape().getType();
1175   p.printOptionalArrowTypeList(op.getResultTypes());
1176   p.printRegion(op.region());
1177   p.printOptionalAttrDict(op->getAttrs());
1178 }
1179 
1180 #define GET_OP_CLASSES
1181 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1182