1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Dialect/Traits.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 #include "llvm/ADT/SmallString.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 using namespace mlir;
24 using namespace mlir::shape;
25 
26 namespace {
27 #include "ShapeCanonicalization.inc"
28 }
29 
30 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
31   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
32 }
33 
34 static bool isErrorPropagationPossible(TypeRange operandTypes) {
35   return llvm::any_of(operandTypes, [](Type ty) {
36     return ty.isa<SizeType, ShapeType, ValueShapeType>();
37   });
38 }
39 
40 static LogicalResult verifySizeOrIndexOp(Operation *op) {
41   assert(op != nullptr && op->getNumResults() == 1);
42   Type resultTy = op->getResultTypes().front();
43   if (isErrorPropagationPossible(op->getOperandTypes())) {
44     if (!resultTy.isa<SizeType>())
45       return op->emitOpError()
46              << "if at least one of the operands can hold error values then "
47                 "the result must be of type `size` to propagate them";
48   }
49   return success();
50 }
51 
52 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
53   assert(op != nullptr && op->getNumResults() == 1);
54   Type resultTy = op->getResultTypes().front();
55   if (isErrorPropagationPossible(op->getOperandTypes())) {
56     if (!resultTy.isa<ShapeType>())
57       return op->emitOpError()
58              << "if at least one of the operands can hold error values then "
59                 "the result must be of type `shape` to propagate them";
60   }
61   return success();
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // InlinerInterface
66 //===----------------------------------------------------------------------===//
67 
namespace {
/// This class defines the interface for inlining shape dialect ops. All shape
/// dialect regions and operations are unconditionally legal to inline.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  // Shape dialect regions carry no inlining restrictions.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect. Always legal for shape ops.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace
89 
/// Registers the shape dialect's operations (from the generated op list),
/// types, and the inliner interface.
void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect
  // is still evolving it makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
102 
103 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
104                                              Attribute value, Type type,
105                                              Location loc) {
106   if (type.isa<ShapeType>() ||
107       type == getExtentTensorType(builder.getContext()))
108     return builder.create<ConstShapeOp>(loc, type,
109                                         value.cast<DenseIntElementsAttr>());
110   if (type.isa<SizeType>())
111     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
112   if (type.isa<WitnessType>())
113     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
114   if (ConstantOp::isBuildableWith(value, type))
115     return builder.create<ConstantOp>(loc, type, value);
116   return nullptr;
117 }
118 
119 /// Parse a type registered to this dialect.
120 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
121   StringRef keyword;
122   if (parser.parseKeyword(&keyword))
123     return Type();
124 
125   if (keyword == "shape")
126     return ShapeType::get(getContext());
127   if (keyword == "size")
128     return SizeType::get(getContext());
129   if (keyword == "value_shape")
130     return ValueShapeType::get(getContext());
131   if (keyword == "witness")
132     return WitnessType::get(getContext());
133 
134   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
135   return Type();
136 }
137 
138 /// Print a type registered to this dialect.
139 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
140   TypeSwitch<Type>(type)
141       .Case<ShapeType>([&](Type) { os << "shape"; })
142       .Case<SizeType>([&](Type) { os << "size"; })
143       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
144       .Case<WitnessType>([&](Type) { os << "witness"; })
145       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
146 }
147 
148 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
149                                                      NamedAttribute attribute) {
150   // Verify shape.lib attribute.
151   if (attribute.first == "shape.lib") {
152     if (!op->hasTrait<OpTrait::SymbolTable>())
153       return op->emitError(
154           "shape.lib attribute may only be on op implementing SymbolTable");
155 
156     if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) {
157       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
158       if (!symbol)
159         return op->emitError("shape function library ")
160                << symbolRef << " not found";
161       return isa<shape::FunctionLibraryOp>(symbol)
162                  ? success()
163                  : op->emitError()
164                        << symbolRef << " required to be shape function library";
165     }
166 
167     if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) {
168       // Verify all entries are function libraries and mappings in libraries
169       // refer to unique ops.
170       DenseSet<Identifier> key;
171       for (auto it : arr) {
172         if (!it.isa<SymbolRefAttr>())
173           return op->emitError(
174               "only SymbolRefAttr allowed in shape.lib attribute array");
175 
176         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
177             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
178         if (!shapeFnLib)
179           return op->emitError()
180                  << it << " does not refer to FunctionLibraryOp";
181         for (auto mapping : shapeFnLib.mapping()) {
182           if (!key.insert(mapping.first).second) {
183             return op->emitError("only one op to shape mapping allowed, found "
184                                  "multiple for `")
185                    << mapping.first << "`";
186           }
187         }
188       }
189       return success();
190     }
191 
192     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
193                          "allowed as shape.lib attribute");
194   }
195   return success();
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // AnyOp
200 //===----------------------------------------------------------------------===//
201 
202 // TODO: Canonicalization should be implemented for shapes that can be
203 // determined through mixtures of the known dimensions of the inputs.
204 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
205   // Only the last operand is checked because AnyOp is commutative.
206   if (operands.back())
207     return operands.back();
208 
209   return nullptr;
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // AssumingOp
214 //===----------------------------------------------------------------------===//
215 
/// Parses `shape.assuming %witness [-> (result-types)] { region } [attr-dict]`,
/// inserting the implicit terminator when it is elided.
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // The witness operand is always of !shape.witness type, so the type is not
  // spelled out in the assembly format.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
242 
243 static void print(OpAsmPrinter &p, AssumingOp op) {
244   bool yieldsResults = !op.results().empty();
245 
246   p << AssumingOp::getOperationName() << " " << op.witness();
247   if (yieldsResults) {
248     p << " -> (" << op.getResultTypes() << ")";
249   }
250   p.printRegion(op.doRegion(),
251                 /*printEntryBlockArgs=*/false,
252                 /*printBlockTerminators=*/yieldsResults);
253   p.printOptionalAttrDict(op->getAttrs());
254 }
255 
256 namespace {
257 // Removes AssumingOp with a passing witness and inlines the region.
258 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
259   using OpRewritePattern<AssumingOp>::OpRewritePattern;
260 
261   LogicalResult matchAndRewrite(AssumingOp op,
262                                 PatternRewriter &rewriter) const override {
263     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
264     if (!witness || !witness.passingAttr())
265       return failure();
266 
267     AssumingOp::inlineRegionIntoParent(op, rewriter);
268     return success();
269   }
270 };
271 } // namespace
272 
void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  // If taking a statically-known passing witness, inline the body region into
  // the parent (see AssumingWithTrue above).
  patterns.add<AssumingWithTrue>(context);
}
278 
279 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
280 void AssumingOp::getSuccessorRegions(
281     Optional<unsigned> index, ArrayRef<Attribute> operands,
282     SmallVectorImpl<RegionSuccessor> &regions) {
283   // AssumingOp has unconditional control flow into the region and back to the
284   // parent, so return the correct RegionSuccessor purely based on the index
285   // being None or 0.
286   if (index.hasValue()) {
287     regions.push_back(RegionSuccessor(getResults()));
288     return;
289   }
290 
291   regions.push_back(RegionSuccessor(&doRegion()));
292 }
293 
/// Inlines the single-block body of `op` into its parent block, replacing the
/// op's results with the values yielded by its terminator and erasing both
/// the AssumingOp and its AssumingYieldOp.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Split the parent block at the insertion point so the assuming body can be
  // spliced in between the two halves.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp. The yield must be captured
  // before inlining so its operands can replace the op's results.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
313 
314 //===----------------------------------------------------------------------===//
315 // AssumingAllOp
316 //===----------------------------------------------------------------------===//
317 
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // AssumingAllOneOp is a declarative pattern generated from
  // ShapeCanonicalization.inc; presumably it collapses an assuming_all with a
  // single operand — see the TableGen definitions to confirm.
  patterns.add<AssumingAllOneOp>(context);
}
322 
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method. Note that this erases operands even when folding later
    // bails out on a non-constant operand, so the op may shrink without
    // being folded away.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
343 
344 static LogicalResult verify(AssumingAllOp op) {
345   // Ensure that AssumingAllOp contains at least one operand
346   if (op.getNumOperands() == 0)
347     return op.emitOpError("no operands specified");
348 
349   return success();
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // BroadcastOp
354 //===----------------------------------------------------------------------===//
355 
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  // A broadcast of a single shape is that shape itself.
  if (operands.size() == 1)
    return shapes().front();

  // TODO: Support folding with more than 2 input shapes
  if (shapes().size() > 2)
    return nullptr;

  if (!operands[1])
    return nullptr;

  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  // A rank-0 (scalar) shape broadcasts to the other operand unchanged, so
  // the other operand need not even be constant in that case.
  if (rhsShape.empty())
    return shapes()[0];

  if (!operands[0])
    return nullptr;

  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  if (lhsShape.empty())
    return shapes()[1];

  SmallVector<int64_t, 6> resultShape;
  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
388 
389 static LogicalResult verify(BroadcastOp op) {
390   return verifyShapeOrExtentTensorOp(op);
391 }
392 
393 namespace {
394 template <typename OpTy>
395 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
396   using OpRewritePattern<OpTy>::OpRewritePattern;
397 
398   LogicalResult matchAndRewrite(OpTy op,
399                                 PatternRewriter &rewriter) const override {
400     // Find unique operands.
401     SmallVector<Value, 2> unique;
402     for (Value v : op.getOperands()) {
403       if (!llvm::is_contained(unique, v))
404         unique.push_back(v);
405     }
406 
407     // Reduce op to equivalent with unique operands.
408     if (unique.size() < op.getNumOperands()) {
409       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
410                                         op->getAttrs());
411       return success();
412     }
413 
414     return failure();
415   }
416 };
417 
418 struct BroadcastForwardSingleOperandPattern
419     : public OpRewritePattern<BroadcastOp> {
420   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
421 
422   LogicalResult matchAndRewrite(BroadcastOp op,
423                                 PatternRewriter &rewriter) const override {
424     if (op.getNumOperands() == 1) {
425       rewriter.replaceOp(op, op.shapes().front());
426       return success();
427     }
428     return failure();
429   }
430 };
431 } // namespace
432 
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  // Forward a lone operand directly, and drop repeated operands, which do
  // not change the broadcasted result.
  patterns.add<BroadcastForwardSingleOperandPattern,
               RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
}
438 
439 //===----------------------------------------------------------------------===//
440 // ConcatOp
441 //===----------------------------------------------------------------------===//
442 
443 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
444   if (!operands[0] || !operands[1])
445     return nullptr;
446   auto lhsShape = llvm::to_vector<6>(
447       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
448   auto rhsShape = llvm::to_vector<6>(
449       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
450   SmallVector<int64_t, 6> resultShape;
451   resultShape.append(lhsShape.begin(), lhsShape.end());
452   resultShape.append(rhsShape.begin(), rhsShape.end());
453   Builder builder(getContext());
454   return builder.getIndexTensorAttr(resultShape);
455 }
456 
457 //===----------------------------------------------------------------------===//
458 // ConstShapeOp
459 //===----------------------------------------------------------------------===//
460 
/// Prints `shape.const_shape [attr-dict] [e0, e1, ...] : type`. The "shape"
/// attribute is printed inline as the extent list, so it is elided from the
/// attribute dictionary.
static void print(OpAsmPrinter &p, ConstShapeOp &op) {
  p << "shape.const_shape ";
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(op.shape().getValues<int64_t>(), p,
                  [&](int64_t i) { p << i; });
  p << "] : ";
  p.printType(op.getType());
}
470 
/// Parses `shape.const_shape [attr-dict] [e0, e1, ...] : type`, converting
/// the bracketed extent list into the op's "shape" index-tensor attribute.
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Every extent must be an integer attribute.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  // Parse the trailing result type annotation.
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
500 
// Folds to the constant shape attribute stored on the op.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
502 
void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  // TensorCastConstShape is a declarative pattern generated from
  // ShapeCanonicalization.inc.
  patterns.add<TensorCastConstShape>(context);
}
507 
508 //===----------------------------------------------------------------------===//
509 // CstrBroadcastableOp
510 //===----------------------------------------------------------------------===//
511 
512 namespace {
513 // Given an input shape Value, try to obtain the shape's values.
514 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
515   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
516     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
517     if (!type.hasRank())
518       return failure();
519     shapeValues = llvm::to_vector<6>(type.getShape());
520     return success();
521   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
522     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
523     return success();
524   } else {
525     return failure();
526   }
527 }
528 } // namespace
529 
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point
  // that does not result in folding. CstrBroadcastableEqOps is generated from
  // ShapeCanonicalization.inc.
  patterns.add<CstrBroadcastableEqOps>(context);
}
537 
// Return true if at most one attribute represents a non-scalar (non-empty)
// shape. A null attribute (non-constant shape) is conservatively treated as
// potentially non-scalar.
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}
551 
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  // Next, check whether all operands are constant and statically known to be
  // broadcastable.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes (extents recovered from the defining ops rather than
  // from constant-folded attributes).
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : shapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}
586 
static LogicalResult verify(CstrBroadcastableOp op) {
  // Ensure that CstrBroadcastableOp contains at least two shape operands;
  // broadcastability of fewer shapes is trivial.
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");
  return success();
}
593 
594 //===----------------------------------------------------------------------===//
595 // CstrEqOp
596 //===----------------------------------------------------------------------===//
597 
void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal SSA values, return a passing witness. CstrEqEqOps is
  // generated from ShapeCanonicalization.inc.
  patterns.add<CstrEqEqOps>(context);
}
603 
604 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
605   if (llvm::all_of(operands,
606                    [&](Attribute a) { return a && a == operands[0]; }))
607     return BoolAttr::get(getContext(), true);
608 
609   // Because a failing witness result here represents an eventual assertion
610   // failure, we do not try to replace it with a constant witness. Similarly, we
611   // cannot if there are any non-const inputs.
612   return nullptr;
613 }
614 
615 //===----------------------------------------------------------------------===//
616 // ConstSizeOp
617 //===----------------------------------------------------------------------===//
618 
/// Builds a ConstSizeOp from a raw integer, wrapping it in an index attribute.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
623 
// Folds to the constant integer value stored on the op.
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
625 
626 void ConstSizeOp::getAsmResultNames(
627     llvm::function_ref<void(Value, StringRef)> setNameFn) {
628   SmallString<4> buffer;
629   llvm::raw_svector_ostream os(buffer);
630   os << "c" << value();
631   setNameFn(getResult(), os.str());
632 }
633 
634 //===----------------------------------------------------------------------===//
635 // ConstWitnessOp
636 //===----------------------------------------------------------------------===//
637 
// Folds to the constant boolean "passing" attribute stored on the op.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
639 
640 //===----------------------------------------------------------------------===//
641 // CstrRequireOp
642 //===----------------------------------------------------------------------===//
643 
// Folds to the input predicate when it is constant; a null attribute (a
// non-constant input) leaves the op unfolded.
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}
647 
648 //===----------------------------------------------------------------------===//
649 // DivOp
650 //===----------------------------------------------------------------------===//
651 
652 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
653   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
654   if (!lhs)
655     return nullptr;
656   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
657   if (!rhs)
658     return nullptr;
659 
660   // Division in APInt does not follow floor(lhs, rhs) when the result is
661   // negative. Rather, APInt rounds toward zero.
662   APInt quotient, remainder;
663   APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
664   if (quotient.isNegative() && !remainder.isNullValue()) {
665     quotient -= 1;
666   }
667 
668   Type indexTy = IndexType::get(getContext());
669   return IntegerAttr::get(indexTy, quotient);
670 }
671 
672 //===----------------------------------------------------------------------===//
673 // ShapeEqOp
674 //===----------------------------------------------------------------------===//
675 
OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
  bool allSame = true;
  // Folding requires every operand to be constant; bail out early on a
  // non-constant first operand.
  if (!operands.empty() && !operands[0])
    return {};
  // NOTE(review): drop_front asserts when `operands` is empty; presumably the
  // op requires at least one operand — confirm against the op definition.
  for (Attribute operand : operands.drop_front(1)) {
    if (!operand)
      return {};
    allSame = allSame && operand == operands[0];
  }
  return BoolAttr::get(getContext(), allSame);
}
687 
688 //===----------------------------------------------------------------------===//
689 // IndexToSizeOp
690 //===----------------------------------------------------------------------===//
691 
692 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
693   // Constant values of both types, `shape.size` and `index`, are represented as
694   // `IntegerAttr`s which makes constant folding simple.
695   if (Attribute arg = operands[0])
696     return arg;
697   return {};
698 }
699 
void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // SizeToIndexToSizeCanonicalization is generated from
  // ShapeCanonicalization.inc; it targets round-trips through size/index
  // conversions.
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}
704 
705 //===----------------------------------------------------------------------===//
706 // FromExtentsOp
707 //===----------------------------------------------------------------------===//
708 
709 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
710   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
711     return nullptr;
712   SmallVector<int64_t, 6> extents;
713   for (auto attr : operands)
714     extents.push_back(attr.cast<IntegerAttr>().getInt());
715   Builder builder(getContext());
716   return builder.getIndexTensorAttr(extents);
717 }
718 
719 //===----------------------------------------------------------------------===//
720 // FunctionLibraryOp
721 //===----------------------------------------------------------------------===//
722 
/// Builds a function library with the given symbol name, creating its body
/// region and implicit terminator.
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  ensureTerminator(*result.addRegion(), builder, result.location);
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}
729 
730 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
731   auto attr = mapping()
732                   .get(op->getName().getIdentifier())
733                   .dyn_cast_or_null<FlatSymbolRefAttr>();
734   if (!attr)
735     return nullptr;
736   return lookupSymbol<FuncOp>(attr);
737 }
738 
/// Parses `shape.function_library @name [attributes {...}] { region }
/// mapping {...}`, inserting the implicit terminator when elided.
ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
                                   OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                      result.location);
  // The trailing mapping dictionary associates op names with shape function
  // symbols.
  if (parser.parseKeyword("mapping"))
    return failure();

  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}
766 
/// Prints `shape.function_library @name [attributes {...}] { region }
/// mapping {...}`, matching parseFunctionLibraryOp.
void print(OpAsmPrinter &p, FunctionLibraryOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.getName());
  // The symbol name and mapping are printed explicitly, so elide them from
  // the attribute dictionary.
  p.printOptionalAttrDictWithKeyword(
      op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(op.mappingAttr());
}
777 
778 //===----------------------------------------------------------------------===//
779 // GetExtentOp
780 //===----------------------------------------------------------------------===//
781 
782 Optional<int64_t> GetExtentOp::getConstantDim() {
783   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
784     return constSizeOp.value().getLimitedValue();
785   if (auto constantOp = dim().getDefiningOp<ConstantOp>())
786     return constantOp.value().cast<IntegerAttr>().getInt();
787   return llvm::None;
788 }
789 
790 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
791   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
792   if (!elements)
793     return nullptr;
794   Optional<int64_t> dim = getConstantDim();
795   if (!dim.hasValue())
796     return nullptr;
797   if (dim.getValue() >= elements.getNumElements())
798     return nullptr;
799   return elements.getValue({(uint64_t)dim.getValue()});
800 }
801 
802 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
803                         int64_t dim) {
804   auto loc = result.location;
805   auto dimAttr = builder.getIndexAttr(dim);
806   if (shape.getType().isa<ShapeType>()) {
807     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
808     build(builder, result, builder.getType<SizeType>(), shape, dim);
809   } else {
810     Value dim =
811         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
812     build(builder, result, builder.getIndexType(), shape, dim);
813   }
814 }
815 
816 //===----------------------------------------------------------------------===//
817 // IsBroadcastableOp
818 //===----------------------------------------------------------------------===//
819 
void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  // Duplicate operands do not affect broadcastability, so they can be removed.
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
824 
825 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
826   // Can always broadcast fewer than two shapes.
827   if (operands.size() < 2) {
828     return BoolAttr::get(getContext(), true);
829   }
830 
831   return nullptr;
832 }
833 
834 //===----------------------------------------------------------------------===//
835 // RankOp
836 //===----------------------------------------------------------------------===//
837 
838 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
839   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
840   if (!shape)
841     return {};
842   int64_t rank = shape.getNumElements();
843   Builder builder(getContext());
844   return builder.getIndexAttr(rank);
845 }
846 
847 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
848 /// Constant folding fails in cases where only the rank is constant, not the
849 /// shape itself.
850 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
851 ///
852 /// Example:
853 ///
854 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
855 /// %rank = shape.rank %shape
856 ///
857 /// becomes
858 ///
859 /// %rank = shape.const_size 3
860 
861 namespace {
862 struct RankShapeOfCanonicalizationPattern
863     : public OpRewritePattern<shape::RankOp> {
864   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
865 
866   LogicalResult matchAndRewrite(shape::RankOp op,
867                                 PatternRewriter &rewriter) const override {
868     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
869     if (!shapeOfOp)
870       return failure();
871     auto rankedTensorType =
872         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
873     if (!rankedTensorType)
874       return failure();
875     int64_t rank = rankedTensorType.getRank();
876     if (op.getType().isa<IndexType>()) {
877       rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
878     } else if (op.getType().isa<shape::SizeType>()) {
879       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
880     } else {
881       return failure();
882     }
883     return success();
884   }
885 };
886 } // namespace
887 
void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Matches shape.rank(shape.shape_of(%ranked_tensor)); see the example above.
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}
892 
893 //===----------------------------------------------------------------------===//
894 // NumElementsOp
895 //===----------------------------------------------------------------------===//
896 
897 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
898 
899   // Fold only when argument constant.
900   Attribute shape = operands[0];
901   if (!shape)
902     return {};
903 
904   APInt product(64, 1);
905   for (auto value : shape.cast<DenseIntElementsAttr>())
906     product *= value;
907   Builder builder(getContext());
908   return builder.getIndexAttr(product.getLimitedValue());
909 }
910 
911 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
912                           Value shape) {
913   if (shape.getType().isa<ShapedType>()) {
914     auto type = builder.getIndexType();
915     return build(builder, result, type, shape);
916   }
917   auto type = SizeType::get(builder.getContext());
918   return build(builder, result, type, shape);
919 }
920 
921 //===----------------------------------------------------------------------===//
922 // MulOp
923 //===----------------------------------------------------------------------===//
924 
925 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
926   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
927   if (!lhs)
928     return nullptr;
929   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
930   if (!rhs)
931     return nullptr;
932   APInt folded = lhs.getValue() * rhs.getValue();
933   Type indexTy = IndexType::get(getContext());
934   return IntegerAttr::get(indexTy, folded);
935 }
936 
937 //===----------------------------------------------------------------------===//
938 // ShapeOfOp
939 //===----------------------------------------------------------------------===//
940 
941 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
942   auto type = getOperand().getType().dyn_cast<ShapedType>();
943   if (!type || !type.hasStaticShape())
944     return nullptr;
945   Builder builder(getContext());
946   return builder.getIndexTensorAttr(type.getShape());
947 }
948 
949 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
950   Type type = arg.getType().isa<ShapedType>()
951                   ? (Type)getExtentTensorType(builder.getContext())
952                   : (Type)builder.getType<ShapeType>();
953   return ShapeOfOp::build(builder, result, type, arg);
954 }
955 
956 namespace {
957 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
958   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
959 
960   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
961                                 PatternRewriter &rewriter) const override {
962     if (!op.arg().getType().isa<ShapedType>())
963       return failure();
964     if (op.getType().isa<ShapedType>())
965       return failure();
966 
967     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
968     return success();
969   }
970 };
971 } // namespace
972 
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  // Rewrites shape_of on tensors to produce the extent tensor type directly.
  patterns.add<ShapeOfWithTensor>(context);
}
977 
978 //===----------------------------------------------------------------------===//
979 // SizeToIndexOp
980 //===----------------------------------------------------------------------===//
981 
982 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
983   // Constant values of both types, `shape.size` and `index`, are represented as
984   // `IntegerAttr`s which makes constant folding simple.
985   if (Attribute arg = operands[0])
986     return arg;
987   return impl::foldCastOp(*this);
988 }
989 
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Tablegen-generated pattern (ShapeCanonicalization.inc); by its name it
  // cancels index -> size -> index round-trips.
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}
994 
995 //===----------------------------------------------------------------------===//
996 // YieldOp
997 //===----------------------------------------------------------------------===//
998 
999 static LogicalResult verify(shape::YieldOp op) {
1000   auto *parentOp = op->getParentOp();
1001   auto results = parentOp->getResults();
1002   auto operands = op.getOperands();
1003 
1004   if (parentOp->getNumResults() != op.getNumOperands())
1005     return op.emitOpError() << "number of operands does not match number of "
1006                                "results of its parent";
1007   for (auto e : llvm::zip(results, operands))
1008     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1009       return op.emitOpError()
1010              << "types mismatch between yield op and its parent";
1011 
1012   return success();
1013 }
1014 
1015 //===----------------------------------------------------------------------===//
1016 // SplitAtOp
1017 //===----------------------------------------------------------------------===//
1018 
1019 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1020                               SmallVectorImpl<OpFoldResult> &results) {
1021   if (!operands[0] || !operands[1])
1022     return failure();
1023   auto shapeVec = llvm::to_vector<6>(
1024       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1025   auto shape = llvm::makeArrayRef(shapeVec);
1026   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1027   // Verify that the split point is in the correct range.
1028   // TODO: Constant fold to an "error".
1029   int64_t rank = shape.size();
1030   if (!(-rank <= splitPoint && splitPoint <= rank))
1031     return failure();
1032   if (splitPoint < 0)
1033     splitPoint += shape.size();
1034   Builder builder(operands[0].getContext());
1035   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1036   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1037   return success();
1038 }
1039 
1040 //===----------------------------------------------------------------------===//
1041 // ToExtentTensorOp
1042 //===----------------------------------------------------------------------===//
1043 
1044 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1045   if (!operands[0])
1046     return impl::foldCastOp(*this);
1047   Builder builder(getContext());
1048   auto shape = llvm::to_vector<6>(
1049       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1050   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1051                                     builder.getIndexType());
1052   return DenseIntElementsAttr::get(type, shape);
1053 }
1054 
1055 //===----------------------------------------------------------------------===//
1056 // ReduceOp
1057 //===----------------------------------------------------------------------===//
1058 
// Builds a `shape.reduce` op and pre-creates its body block with the implied
// argument signature: (index, extent, init-values...). The caller is
// responsible for filling in the body.
void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  result.addOperands(shape);
  result.addOperands(initVals);

  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  // Argument 0: the iteration index, always of `index` type.
  bodyBlock.addArgument(builder.getIndexType());

  // Argument 1: the extent. For an extent tensor operand this is the tensor's
  // element type; for a `shape.shape` operand it is `shape.size`.
  Type elementType;
  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
    elementType = tensorType.getElementType();
  else
    elementType = SizeType::get(builder.getContext());
  bodyBlock.addArgument(elementType);

  // Remaining arguments mirror the initial values, and the op returns one
  // result per initial value.
  for (Type initValType : initVals.getTypes()) {
    bodyBlock.addArgument(initValType);
    result.addTypes(initValType);
  }
}
1081 
// Verifies that a ReduceOp's body block has the expected signature
// (index, extent, init-values...), with the extent type matching the kind of
// shape operand (`shape.shape` vs. extent tensor).
static LogicalResult verify(ReduceOp op) {
  // Verify block arg types.
  Block &block = op.region().front();

  // The block takes index, extent, and aggregated values as arguments.
  auto blockArgsCount = op.initVals().size() + 2;
  if (block.getNumArguments() != blockArgsCount)
    return op.emitOpError() << "ReduceOp body is expected to have "
                            << blockArgsCount << " arguments";

  // The first block argument is the index and must always be of type `index`.
  if (!block.getArgument(0).getType().isa<IndexType>())
    return op.emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  // The second block argument is the extent and must be of type `size` or
  // `index`, depending on whether the reduce operation is applied to a shape or
  // to an extent tensor.
  Type extentTy = block.getArgument(1).getType();
  if (op.shape().getType().isa<ShapeType>()) {
    if (!extentTy.isa<SizeType>())
      return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
                            "SizeType if the ReduceOp operates on a ShapeType");
  } else {
    if (!extentTy.isa<IndexType>())
      return op.emitOpError(
          "argument 1 of ReduceOp body is expected to be of IndexType if the "
          "ReduceOp operates on an extent tensor");
  }

  // The remaining block arguments must match the initial values one-to-one.
  for (auto type : llvm::enumerate(op.initVals()))
    if (block.getArgument(type.index() + 2).getType() != type.value().getType())
      return op.emitOpError()
             << "type mismatch between argument " << type.index() + 2
             << " of ReduceOp body and initial value " << type.index();
  return success();
}
1119 
// Parses a `shape.reduce` op of the form:
//   shape.reduce(%shape, %init...) : <shape-or-extent-tensor-type>
//       -> <result-types> { region } [attr-dict]
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
  // Parse operands: the first is the shape, the rest are initial values.
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands. The initial values take their types from the result
  // type list (one result per initial value).
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body region; its block arguments are spelled inside the region.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse the trailing optional attribute dictionary.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1149 
// Prints a ReduceOp in the custom form mirrored by parseReduceOp:
//   shape.reduce(%shape, %init...) : type -> result-types { region } attrs
static void print(OpAsmPrinter &p, ReduceOp op) {
  p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
    << ") : " << op.shape().getType();
  p.printOptionalArrowTypeList(op.getResultTypes());
  p.printRegion(op.region());
  p.printOptionalAttrDict(op->getAttrs());
}
1157 
1158 #define GET_OP_CLASSES
1159 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1160