1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Dialect/Traits.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 #include "llvm/ADT/SmallString.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 using namespace mlir;
24 using namespace mlir::shape;
25 
26 namespace {
27 #include "ShapeCanonicalization.inc"
28 }
29 
30 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
31   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
32 }
33 
34 static bool isErrorPropagationPossible(TypeRange operandTypes) {
35   return llvm::any_of(operandTypes, [](Type ty) {
36     return ty.isa<SizeType, ShapeType, ValueShapeType>();
37   });
38 }
39 
40 static LogicalResult verifySizeOrIndexOp(Operation *op) {
41   assert(op != nullptr && op->getNumResults() == 1);
42   Type resultTy = op->getResultTypes().front();
43   if (isErrorPropagationPossible(op->getOperandTypes())) {
44     if (!resultTy.isa<SizeType>())
45       return op->emitOpError()
46              << "if at least one of the operands can hold error values then "
47                 "the result must be of type `size` to propagate them";
48   }
49   return success();
50 }
51 
52 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
53   assert(op != nullptr && op->getNumResults() == 1);
54   Type resultTy = op->getResultTypes().front();
55   if (isErrorPropagationPossible(op->getOperandTypes())) {
56     if (!resultTy.isa<ShapeType>())
57       return op->emitOpError()
58              << "if at least one of the operands can hold error values then "
59                 "the result must be of type `shape` to propagate them";
60   }
61   return success();
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // InlinerInterface
66 //===----------------------------------------------------------------------===//
67 
namespace {
/// This class defines the interface for inlining shape dialect ops.
/// Both hooks unconditionally return true: all shape ops and regions are
/// side-effect-compatible with inlining.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace
89 
// Registers the dialect's operations (from the tablegen-generated list),
// types, and the inliner interface.
void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect
  // is still evolving it makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
102 
// Materializes a constant attribute as a dialect-specific constant op chosen
// by the requested result type; returns nullptr if no suitable op exists.
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  // Both `!shape.shape` and extent tensors are materialized as const_shape.
  if (type.isa<ShapeType>() ||
      type == getExtentTensorType(builder.getContext()))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  // Fall back to a standard constant for index/integer types etc.
  if (ConstantOp::isBuildableWith(value, type))
    return builder.create<ConstantOp>(loc, type, value);
  return nullptr;
}
118 
119 /// Parse a type registered to this dialect.
/// Parse a type registered to this dialect. The dialect's types are plain
/// keywords with no parameters: `shape`, `size`, `value_shape`, `witness`.
Type ShapeDialect::parseType(DialectAsmParser &parser) const {
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();

  if (keyword == "shape")
    return ShapeType::get(getContext());
  if (keyword == "size")
    return SizeType::get(getContext());
  if (keyword == "value_shape")
    return ValueShapeType::get(getContext());
  if (keyword == "witness")
    return WitnessType::get(getContext());

  // Unknown keyword: emit a diagnostic and return a null type.
  parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
  return Type();
}
137 
138 /// Print a type registered to this dialect.
139 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
140   TypeSwitch<Type>(type)
141       .Case<ShapeType>([&](Type) { os << "shape"; })
142       .Case<SizeType>([&](Type) { os << "size"; })
143       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
144       .Case<WitnessType>([&](Type) { os << "witness"; })
145       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
146 }
147 
148 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
149                                                      NamedAttribute attribute) {
150   // Verify shape.lib attribute.
151   if (attribute.first == "shape.lib") {
152     if (!op->hasTrait<OpTrait::SymbolTable>())
153       return op->emitError(
154           "shape.lib attribute may only be on op implementing SymbolTable");
155 
156     if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) {
157       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
158       if (!symbol)
159         return op->emitError("shape function library ")
160                << symbolRef << " not found";
161       return isa<shape::FunctionLibraryOp>(symbol)
162                  ? success()
163                  : op->emitError()
164                        << symbolRef << " required to be shape function library";
165     }
166 
167     if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) {
168       // Verify all entries are function libraries and mappings in libraries
169       // refer to unique ops.
170       DenseSet<Identifier> key;
171       for (auto it : arr) {
172         if (!it.isa<SymbolRefAttr>())
173           return op->emitError(
174               "only SymbolRefAttr allowed in shape.lib attribute array");
175 
176         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
177             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
178         if (!shapeFnLib)
179           return op->emitError()
180                  << it << " does not refer to FunctionLibraryOp";
181         for (auto mapping : shapeFnLib.mapping()) {
182           if (!key.insert(mapping.first).second) {
183             return op->emitError("only one op to shape mapping allowed, found "
184                                  "multiple for `")
185                    << mapping.first << "`";
186           }
187         }
188       }
189       return success();
190     }
191 
192     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
193                          "allowed as shape.lib attribute");
194   }
195   return success();
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // AnyOp
200 //===----------------------------------------------------------------------===//
201 
202 // TODO: Canonicalization should be implemented for shapes that can be
203 // determined through mixtures of the known dimensions of the inputs.
204 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
205   // Only the last operand is checked because AnyOp is commutative.
206   if (operands.back())
207     return operands.back();
208 
209   return nullptr;
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // AssumingOp
214 //===----------------------------------------------------------------------===//
215 
// Custom parser for AssumingOp:
//   shape.assuming %witness [-> (result-types)] { region } [attr-dict]
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // The single operand is always a witness value.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
242 
243 static void print(OpAsmPrinter &p, AssumingOp op) {
244   bool yieldsResults = !op.results().empty();
245 
246   p << AssumingOp::getOperationName() << " " << op.witness();
247   if (yieldsResults) {
248     p << " -> (" << op.getResultTypes() << ")";
249   }
250   p.printRegion(op.doRegion(),
251                 /*printEntryBlockArgs=*/false,
252                 /*printBlockTerminators=*/yieldsResults);
253   p.printOptionalAttrDict(op->getAttrs());
254 }
255 
256 namespace {
257 // Removes AssumingOp with a passing witness and inlines the region.
258 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
259   using OpRewritePattern<AssumingOp>::OpRewritePattern;
260 
261   LogicalResult matchAndRewrite(AssumingOp op,
262                                 PatternRewriter &rewriter) const override {
263     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
264     if (!witness || !witness.passingAttr())
265       return failure();
266 
267     AssumingOp::inlineRegionIntoParent(op, rewriter);
268     return success();
269   }
270 };
271 } // namespace
272 
void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                             MLIRContext *context) {
  // If taking a passing witness, inline region.
  patterns.insert<AssumingWithTrue>(context);
}
278 
279 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
280 void AssumingOp::getSuccessorRegions(
281     Optional<unsigned> index, ArrayRef<Attribute> operands,
282     SmallVectorImpl<RegionSuccessor> &regions) {
283   // AssumingOp has unconditional control flow into the region and back to the
284   // parent, so return the correct RegionSuccessor purely based on the index
285   // being None or 0.
286   if (index.hasValue()) {
287     regions.push_back(RegionSuccessor(getResults()));
288     return;
289   }
290 
291   regions.push_back(RegionSuccessor(&doRegion()));
292 }
293 
// Splices the body of `op` into its parent block at the rewriter's insertion
// point, replacing the op's results with the yielded values. The rewriter
// call order below (split, inline, replace, erase, merge) is significant.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Split the parent block so the assuming body can be placed in between.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
313 
314 //===----------------------------------------------------------------------===//
315 // AssumingAllOp
316 //===----------------------------------------------------------------------===//
317 
void AssumingAllOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Replace assuming_all with a single operand by that operand (generated
  // pattern from ShapeCanonicalization.td).
  patterns.insert<AssumingAllOneOp>(context);
}
322 
// Folds assuming_all over its statically known witnesses. NOTE(review): this
// fold mutates the op in place by erasing constant operands — unusual for a
// fold, but deliberate here so known-passing witnesses are dropped even when
// the op itself cannot be folded away.
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
343 
344 static LogicalResult verify(AssumingAllOp op) {
345   // Ensure that AssumingAllOp contains at least one operand
346   if (op.getNumOperands() == 0)
347     return op.emitOpError("no operands specified");
348 
349   return success();
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // BroadcastOp
354 //===----------------------------------------------------------------------===//
355 
// Folds broadcast of two constant shapes, with shortcuts when either side is
// a scalar (rank-0) shape — broadcasting with a scalar is the identity.
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[1])
    return nullptr;

  // TODO: Support folding with more than 2 input shapes
  if (shapes().size() > 2)
    return nullptr;

  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  // Broadcasting with a scalar rhs yields the lhs unchanged.
  if (rhsShape.empty())
    return shapes()[0];

  if (!operands[0])
    return nullptr;

  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  // Broadcasting with a scalar lhs yields the rhs unchanged.
  if (lhsShape.empty())
    return shapes()[1];

  SmallVector<int64_t, 6> resultShape;
  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
385 
static LogicalResult verify(BroadcastOp op) {
  // Ensure that BroadcastOp combines at least two input shapes.
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");

  return verifyShapeOrExtentTensorOp(op);
}
393 
394 //===----------------------------------------------------------------------===//
395 // ConcatOp
396 //===----------------------------------------------------------------------===//
397 
398 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
399   if (!operands[0] || !operands[1])
400     return nullptr;
401   auto lhsShape = llvm::to_vector<6>(
402       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
403   auto rhsShape = llvm::to_vector<6>(
404       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
405   SmallVector<int64_t, 6> resultShape;
406   resultShape.append(lhsShape.begin(), lhsShape.end());
407   resultShape.append(rhsShape.begin(), rhsShape.end());
408   Builder builder(getContext());
409   return builder.getIndexTensorAttr(resultShape);
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // ConstShapeOp
414 //===----------------------------------------------------------------------===//
415 
416 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
417   p << "shape.const_shape ";
418   p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
419   p << "[";
420   interleaveComma(op.shape().getValues<int64_t>(), p,
421                   [&](int64_t i) { p << i; });
422   p << "] : ";
423   p.printType(op.getType());
424 }
425 
// Custom parser for ConstShapeOp; inverse of the printer above.
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Each array element must be an integer extent.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  // Re-encode the extents as the index-tensor "shape" attribute.
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
455 
456 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
457 
void ConstShapeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Fold tensor.cast of a const_shape into the const_shape itself (generated
  // pattern from ShapeCanonicalization.td).
  patterns.insert<TensorCastConstShape>(context);
}
462 
463 //===----------------------------------------------------------------------===//
464 // CstrBroadcastableOp
465 //===----------------------------------------------------------------------===//
466 
467 namespace {
468 // Given an input shape Value, try to obtain the shape's values.
469 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
470   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
471     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
472     if (!type.hasRank())
473       return failure();
474     shapeValues = llvm::to_vector<6>(type.getShape());
475     return success();
476   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
477     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
478     return success();
479   } else {
480     return failure();
481   }
482 }
483 } // namespace
484 
void CstrBroadcastableOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point
  // that does not result in folding.
  patterns.insert<CstrBroadcastableEqOps>(context);
}
492 
493 // Return true if there is exactly one attribute not representing a scalar
494 // broadcast.
495 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
496   bool nonScalarSeen = false;
497   for (Attribute a : attributes) {
498     if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
499       if (nonScalarSeen)
500         return false;
501       nonScalarSeen = true;
502     }
503   }
504   return true;
505 }
506 
// Folds to a passing witness when broadcastability can be proven statically,
// via three increasingly specific checks; never folds to a failing witness.
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  // All operands are constant: check broadcastability on the constants.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are
  // known on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &shape : shapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shape, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}
541 
static LogicalResult verify(CstrBroadcastableOp op) {
  // Ensure that CstrBroadcastableOp constrains at least two input shapes.
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");
  return success();
}
548 
549 //===----------------------------------------------------------------------===//
550 // CstrEqOp
551 //===----------------------------------------------------------------------===//
552 
void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness (generated pattern from
  // ShapeCanonicalization.td).
  patterns.insert<CstrEqEqOps>(context);
}
558 
559 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
560   if (llvm::all_of(operands,
561                    [&](Attribute a) { return a && a == operands[0]; }))
562     return BoolAttr::get(getContext(), true);
563 
564   // Because a failing witness result here represents an eventual assertion
565   // failure, we do not try to replace it with a constant witness. Similarly, we
566   // cannot if there are any non-const inputs.
567   return nullptr;
568 }
569 
570 //===----------------------------------------------------------------------===//
571 // ConstSizeOp
572 //===----------------------------------------------------------------------===//
573 
// Convenience builder that wraps a raw int64_t value as an index attribute.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
578 
579 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
580 
581 void ConstSizeOp::getAsmResultNames(
582     llvm::function_ref<void(Value, StringRef)> setNameFn) {
583   SmallString<4> buffer;
584   llvm::raw_svector_ostream os(buffer);
585   os << "c" << value();
586   setNameFn(getResult(), os.str());
587 }
588 
589 //===----------------------------------------------------------------------===//
590 // ConstWitnessOp
591 //===----------------------------------------------------------------------===//
592 
593 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
594 
595 //===----------------------------------------------------------------------===//
596 // CstrRequireOp
597 //===----------------------------------------------------------------------===//
598 
599 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
600   return operands[0];
601 }
602 
603 //===----------------------------------------------------------------------===//
604 // DivOp
605 //===----------------------------------------------------------------------===//
606 
607 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
608   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
609   if (!lhs)
610     return nullptr;
611   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
612   if (!rhs)
613     return nullptr;
614 
615   // Division in APInt does not follow floor(lhs, rhs) when the result is
616   // negative. Rather, APInt rounds toward zero.
617   APInt quotient, remainder;
618   APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
619   if (quotient.isNegative() && !remainder.isNullValue()) {
620     quotient -= 1;
621   }
622 
623   Type indexTy = IndexType::get(getContext());
624   return IntegerAttr::get(indexTy, quotient);
625 }
626 
627 //===----------------------------------------------------------------------===//
628 // ShapeEqOp
629 //===----------------------------------------------------------------------===//
630 
631 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
632   bool allSame = true;
633   if (!operands.empty() && !operands[0])
634     return {};
635   for (Attribute operand : operands.drop_front(1)) {
636     if (!operand)
637       return {};
638     allSame = allSame && operand == operands[0];
639   }
640   return BoolAttr::get(getContext(), allSame);
641 }
642 
643 //===----------------------------------------------------------------------===//
644 // IndexToSizeOp
645 //===----------------------------------------------------------------------===//
646 
647 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
648   // Constant values of both types, `shape.size` and `index`, are represented as
649   // `IntegerAttr`s which makes constant folding simple.
650   if (Attribute arg = operands[0])
651     return arg;
652   return {};
653 }
654 
void IndexToSizeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Fold size_to_index(index_to_size(x)) round trips (generated pattern from
  // ShapeCanonicalization.td).
  patterns.insert<SizeToIndexToSizeCanonicalization>(context);
}
659 
660 //===----------------------------------------------------------------------===//
661 // FromExtentsOp
662 //===----------------------------------------------------------------------===//
663 
664 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
665   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
666     return nullptr;
667   SmallVector<int64_t, 6> extents;
668   for (auto attr : operands)
669     extents.push_back(attr.cast<IntegerAttr>().getInt());
670   Builder builder(getContext());
671   return builder.getIndexTensorAttr(extents);
672 }
673 
674 //===----------------------------------------------------------------------===//
675 // FunctionLibraryOp
676 //===----------------------------------------------------------------------===//
677 
// Builds an empty function library named `name`, creating its region with a
// terminator and registering the symbol-name attribute.
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  ensureTerminator(*result.addRegion(), builder, result.location);
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}
684 
685 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
686   auto attr = mapping()
687                   .get(op->getName().getIdentifier())
688                   .dyn_cast_or_null<FlatSymbolRefAttr>();
689   if (!attr)
690     return nullptr;
691   return lookupSymbol<FuncOp>(attr);
692 }
693 
// Custom parser for FunctionLibraryOp:
//   shape.function_library @name [attributes {...}] { region } mapping {...}
ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
                                   OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  // Add the implicit terminator if it was elided from the printed form.
  FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                      result.location);
  if (parser.parseKeyword("mapping"))
    return failure();

  // The op-to-shape-function dictionary follows the `mapping` keyword.
  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}
721 
// Custom printer for FunctionLibraryOp; mirrors parseFunctionLibraryOp above.
void print(OpAsmPrinter &p, FunctionLibraryOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.getName());
  // The symbol name and mapping get dedicated syntax; elide them from the
  // generic attribute dictionary.
  p.printOptionalAttrDictWithKeyword(
      op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(op.mappingAttr());
}
732 
733 //===----------------------------------------------------------------------===//
734 // GetExtentOp
735 //===----------------------------------------------------------------------===//
736 
737 Optional<int64_t> GetExtentOp::getConstantDim() {
738   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
739     return constSizeOp.value().getLimitedValue();
740   if (auto constantOp = dim().getDefiningOp<ConstantOp>())
741     return constantOp.value().cast<IntegerAttr>().getInt();
742   return llvm::None;
743 }
744 
745 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
746   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
747   if (!elements)
748     return nullptr;
749   Optional<int64_t> dim = getConstantDim();
750   if (!dim.hasValue())
751     return nullptr;
752   if (dim.getValue() >= elements.getNumElements())
753     return nullptr;
754   return elements.getValue({(uint64_t)dim.getValue()});
755 }
756 
757 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
758                         int64_t dim) {
759   auto loc = result.location;
760   auto dimAttr = builder.getIndexAttr(dim);
761   if (shape.getType().isa<ShapeType>()) {
762     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
763     build(builder, result, builder.getType<SizeType>(), shape, dim);
764   } else {
765     Value dim =
766         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
767     build(builder, result, builder.getIndexType(), shape, dim);
768   }
769 }
770 
771 //===----------------------------------------------------------------------===//
772 // IsBroadcastableOp
773 //===----------------------------------------------------------------------===//
774 
static LogicalResult verify(IsBroadcastableOp op) {
  // Ensure that IsBroadcastableOp compares at least two input shapes.
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");
  return success();
}
781 
782 //===----------------------------------------------------------------------===//
783 // RankOp
784 //===----------------------------------------------------------------------===//
785 
786 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
787   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
788   if (!shape)
789     return {};
790   int64_t rank = shape.getNumElements();
791   Builder builder(getContext());
792   return builder.getIndexAttr(rank);
793 }
794 
795 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
796 /// Constant folding fails in cases where only the rank is constant, not the
797 /// shape itself.
798 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
799 ///
800 /// Example:
801 ///
802 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
803 /// %rank = shape.rank %shape
804 ///
805 /// becomes
806 ///
807 /// %rank = shape.const_size 3
808 
809 namespace {
810 struct RankShapeOfCanonicalizationPattern
811     : public OpRewritePattern<shape::RankOp> {
812   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
813 
814   LogicalResult matchAndRewrite(shape::RankOp op,
815                                 PatternRewriter &rewriter) const override {
816     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
817     if (!shapeOfOp)
818       return failure();
819     auto rankedTensorType =
820         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
821     if (!rankedTensorType)
822       return failure();
823     int64_t rank = rankedTensorType.getRank();
824     if (op.getType().isa<IndexType>()) {
825       rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
826     } else if (op.getType().isa<shape::SizeType>()) {
827       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
828     } else {
829       return failure();
830     }
831     return success();
832   }
833 };
834 } // namespace
835 
/// Registers the canonicalization that folds `shape.rank` of a
/// `shape.shape_of` on a ranked tensor into a constant.
void shape::RankOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<RankShapeOfCanonicalizationPattern>(context);
}
840 
841 //===----------------------------------------------------------------------===//
842 // NumElementsOp
843 //===----------------------------------------------------------------------===//
844 
845 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
846 
847   // Fold only when argument constant.
848   Attribute shape = operands[0];
849   if (!shape)
850     return {};
851 
852   APInt product(64, 1);
853   for (auto value : shape.cast<DenseIntElementsAttr>())
854     product *= value;
855   Builder builder(getContext());
856   return builder.getIndexAttr(product.getLimitedValue());
857 }
858 
859 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
860                           Value shape) {
861   if (shape.getType().isa<ShapedType>()) {
862     auto type = builder.getIndexType();
863     return build(builder, result, type, shape);
864   }
865   auto type = SizeType::get(builder.getContext());
866   return build(builder, result, type, shape);
867 }
868 
869 //===----------------------------------------------------------------------===//
870 // MulOp
871 //===----------------------------------------------------------------------===//
872 
873 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
874   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
875   if (!lhs)
876     return nullptr;
877   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
878   if (!rhs)
879     return nullptr;
880   APInt folded = lhs.getValue() * rhs.getValue();
881   Type indexTy = IndexType::get(getContext());
882   return IntegerAttr::get(indexTy, folded);
883 }
884 
885 //===----------------------------------------------------------------------===//
886 // ShapeOfOp
887 //===----------------------------------------------------------------------===//
888 
889 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
890   auto type = getOperand().getType().dyn_cast<ShapedType>();
891   if (!type || !type.hasStaticShape())
892     return nullptr;
893   Builder builder(getContext());
894   return builder.getIndexTensorAttr(type.getShape());
895 }
896 
897 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
898   Type type = arg.getType().isa<ShapedType>()
899                   ? (Type)getExtentTensorType(builder.getContext())
900                   : (Type)builder.getType<ShapeType>();
901   return ShapeOfOp::build(builder, result, type, arg);
902 }
903 
904 namespace {
905 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
906   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
907 
908   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
909                                 PatternRewriter &rewriter) const override {
910     if (!op.arg().getType().isa<ShapedType>())
911       return failure();
912     if (op.getType().isa<ShapedType>())
913       return failure();
914 
915     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
916     return success();
917   }
918 };
919 } // namespace
920 
/// Registers the canonicalization that rebuilds `shape.shape_of` on shaped
/// values with an extent-tensor result type.
void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                            MLIRContext *context) {
  patterns.insert<ShapeOfWithTensor>(context);
}
925 
926 //===----------------------------------------------------------------------===//
927 // SizeToIndexOp
928 //===----------------------------------------------------------------------===//
929 
930 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
931   // Constant values of both types, `shape.size` and `index`, are represented as
932   // `IntegerAttr`s which makes constant folding simple.
933   if (Attribute arg = operands[0])
934     return arg;
935   return impl::foldCastOp(*this);
936 }
937 
/// Registers the generated pattern that cancels an index_to_size followed by
/// a size_to_index (see ShapeCanonicalization.inc).
void SizeToIndexOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<IndexToSizeToIndexCanonicalization>(context);
}
942 
943 //===----------------------------------------------------------------------===//
944 // YieldOp
945 //===----------------------------------------------------------------------===//
946 
947 static LogicalResult verify(shape::YieldOp op) {
948   auto *parentOp = op->getParentOp();
949   auto results = parentOp->getResults();
950   auto operands = op.getOperands();
951 
952   if (parentOp->getNumResults() != op.getNumOperands())
953     return op.emitOpError() << "number of operands does not match number of "
954                                "results of its parent";
955   for (auto e : llvm::zip(results, operands))
956     if (std::get<0>(e).getType() != std::get<1>(e).getType())
957       return op.emitOpError()
958              << "types mismatch between yield op and its parent";
959 
960   return success();
961 }
962 
963 //===----------------------------------------------------------------------===//
964 // SplitAtOp
965 //===----------------------------------------------------------------------===//
966 
967 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
968                               SmallVectorImpl<OpFoldResult> &results) {
969   if (!operands[0] || !operands[1])
970     return failure();
971   auto shapeVec = llvm::to_vector<6>(
972       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
973   auto shape = llvm::makeArrayRef(shapeVec);
974   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
975   // Verify that the split point is in the correct range.
976   // TODO: Constant fold to an "error".
977   int64_t rank = shape.size();
978   if (!(-rank <= splitPoint && splitPoint <= rank))
979     return failure();
980   if (splitPoint < 0)
981     splitPoint += shape.size();
982   Builder builder(operands[0].getContext());
983   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
984   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
985   return success();
986 }
987 
988 //===----------------------------------------------------------------------===//
989 // ToExtentTensorOp
990 //===----------------------------------------------------------------------===//
991 
992 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
993   if (!operands[0])
994     return impl::foldCastOp(*this);
995   Builder builder(getContext());
996   auto shape = llvm::to_vector<6>(
997       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
998   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
999                                     builder.getIndexType());
1000   return DenseIntElementsAttr::get(type, shape);
1001 }
1002 
1003 //===----------------------------------------------------------------------===//
1004 // ReduceOp
1005 //===----------------------------------------------------------------------===//
1006 
1007 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1008                      ValueRange initVals) {
1009   result.addOperands(shape);
1010   result.addOperands(initVals);
1011 
1012   Region *bodyRegion = result.addRegion();
1013   bodyRegion->push_back(new Block);
1014   Block &bodyBlock = bodyRegion->front();
1015   bodyBlock.addArgument(builder.getIndexType());
1016 
1017   Type elementType;
1018   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1019     elementType = tensorType.getElementType();
1020   else
1021     elementType = SizeType::get(builder.getContext());
1022   bodyBlock.addArgument(elementType);
1023 
1024   for (Type initValType : initVals.getTypes()) {
1025     bodyBlock.addArgument(initValType);
1026     result.addTypes(initValType);
1027   }
1028 }
1029 
/// Verifies the body region of a `shape.reduce`: the entry block must take
/// (index, extent, aggregates...) where the extent type depends on whether
/// the op reduces over a `!shape.shape` or an extent tensor, and each
/// aggregate argument must match the corresponding initial value's type.
static LogicalResult verify(ReduceOp op) {
  // Verify block arg types.
  Block &block = op.region().front();

  // The block takes index, extent, and aggregated values as arguments.
  auto blockArgsCount = op.initVals().size() + 2;
  if (block.getNumArguments() != blockArgsCount)
    return op.emitOpError() << "ReduceOp body is expected to have "
                            << blockArgsCount << " arguments";

  // The first block argument is the index and must always be of type `index`.
  if (!block.getArgument(0).getType().isa<IndexType>())
    return op.emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  // The second block argument is the extent and must be of type `size` or
  // `index`, depending on whether the reduce operation is applied to a shape or
  // to an extent tensor.
  Type extentTy = block.getArgument(1).getType();
  if (op.shape().getType().isa<ShapeType>()) {
    if (!extentTy.isa<SizeType>())
      return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
                            "SizeType if the ReduceOp operates on a ShapeType");
  } else {
    if (!extentTy.isa<IndexType>())
      return op.emitOpError(
          "argument 1 of ReduceOp body is expected to be of IndexType if the "
          "ReduceOp operates on an extent tensor");
  }

  // Remaining block arguments carry the running aggregates; each must match
  // the type of its corresponding initial value.
  for (auto type : llvm::enumerate(op.initVals()))
    if (block.getArgument(type.index() + 2).getType() != type.value().getType())
      return op.emitOpError()
             << "type mismatch between argument " << type.index() + 2
             << " of ReduceOp body and initial value " << type.index();
  return success();
}
1067 
/// Parses a `shape.reduce` of the form:
///   shape.reduce(%shape, %init...) : <shape-or-extent-tensor-type>
///       -> <result-types> { <body> } <attr-dict>
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
  // Parse operands: the shape followed by the initial aggregate values.
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands: the first operand is the shape; each init value has the
  // type of the corresponding result.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body region (block arguments are spelled inside the region).
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse the trailing attribute dictionary, if any.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1097 
/// Prints a `shape.reduce` in the exact form accepted by parseReduceOp:
/// op-name, parenthesized operands, colon-type, optional arrow result types,
/// body region, then the attribute dictionary.
static void print(OpAsmPrinter &p, ReduceOp op) {
  p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
    << ") : " << op.shape().getType();
  p.printOptionalArrowTypeList(op.getResultTypes());
  p.printRegion(op.region());
  p.printOptionalAttrDict(op->getAttrs());
}
1105 
1106 #define GET_OP_CLASSES
1107 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1108