1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Dialect/Traits.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 #include "llvm/ADT/SmallString.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 using namespace mlir;
24 using namespace mlir::shape;
25 
26 namespace {
27 #include "ShapeCanonicalization.inc"
28 }
29 
30 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
31   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
32 }
33 
34 static bool isErrorPropagationPossible(TypeRange operandTypes) {
35   return llvm::any_of(operandTypes, [](Type ty) {
36     return ty.isa<SizeType, ShapeType, ValueShapeType>();
37   });
38 }
39 
40 static LogicalResult verifySizeOrIndexOp(Operation *op) {
41   assert(op != nullptr && op->getNumResults() == 1);
42   Type resultTy = op->getResultTypes().front();
43   if (isErrorPropagationPossible(op->getOperandTypes())) {
44     if (!resultTy.isa<SizeType>())
45       return op->emitOpError()
46              << "if at least one of the operands can hold error values then "
47                 "the result must be of type `size` to propagate them";
48   }
49   return success();
50 }
51 
52 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
53   assert(op != nullptr && op->getNumResults() == 1);
54   Type resultTy = op->getResultTypes().front();
55   if (isErrorPropagationPossible(op->getOperandTypes())) {
56     if (!resultTy.isa<ShapeType>())
57       return op->emitOpError()
58              << "if at least one of the operands can hold error values then "
59                 "the result must be of type `shape` to propagate them";
60   }
61   return success();
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // InlinerInterface
66 //===----------------------------------------------------------------------===//
67 
68 namespace {
69 /// This class defines the interface for inlining shape dialect ops.
70 struct ShapeInlinerInterface : public DialectInlinerInterface {
71   using DialectInlinerInterface::DialectInlinerInterface;
72 
73   // Returns true if the given region 'src' can be inlined into the region
74   // 'dest' that is attached to an operation registered to the current dialect.
75   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
76                        BlockAndValueMapping &) const final {
77     return true;
78   }
79 
80   // Returns true if the given operation 'op', that is registered to this
81   // dialect, can be inlined into the region 'dest' that is attached to an
82   // operation registered to the current dialect.
83   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
84                        BlockAndValueMapping &) const final {
85     return true;
86   }
87 };
88 } // namespace
89 
/// Registers the dialect's operations, types and interfaces.
///
/// The operation list is pulled in from the generated `ShapeOps.cpp.inc` file
/// (selected via `GET_OP_LIST`); the four dialect types and the inliner
/// interface are registered explicitly.
void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
102 
103 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
104                                              Attribute value, Type type,
105                                              Location loc) {
106   if (type.isa<ShapeType>() ||
107       type == getExtentTensorType(builder.getContext()))
108     return builder.create<ConstShapeOp>(loc, type,
109                                         value.cast<DenseIntElementsAttr>());
110   if (type.isa<SizeType>())
111     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
112   if (type.isa<WitnessType>())
113     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
114   if (ConstantOp::isBuildableWith(value, type))
115     return builder.create<ConstantOp>(loc, type, value);
116   return nullptr;
117 }
118 
119 /// Parse a type registered to this dialect.
120 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
121   StringRef keyword;
122   if (parser.parseKeyword(&keyword))
123     return Type();
124 
125   if (keyword == "shape")
126     return ShapeType::get(getContext());
127   if (keyword == "size")
128     return SizeType::get(getContext());
129   if (keyword == "value_shape")
130     return ValueShapeType::get(getContext());
131   if (keyword == "witness")
132     return WitnessType::get(getContext());
133 
134   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
135   return Type();
136 }
137 
138 /// Print a type registered to this dialect.
139 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
140   TypeSwitch<Type>(type)
141       .Case<ShapeType>([&](Type) { os << "shape"; })
142       .Case<SizeType>([&](Type) { os << "size"; })
143       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
144       .Case<WitnessType>([&](Type) { os << "witness"; })
145       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
146 }
147 
148 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
149                                                      NamedAttribute attribute) {
150   // Verify shape.lib attribute.
151   if (attribute.first == "shape.lib") {
152     if (!op->hasTrait<OpTrait::SymbolTable>())
153       return op->emitError(
154           "shape.lib attribute may only be on op implementing SymbolTable");
155 
156     if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) {
157       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
158       if (!symbol)
159         return op->emitError("shape function library ")
160                << symbolRef << " not found";
161       return isa<shape::FunctionLibraryOp>(symbol)
162                  ? success()
163                  : op->emitError()
164                        << symbolRef << " required to be shape function library";
165     }
166 
167     if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) {
168       // Verify all entries are function libraries and mappings in libraries
169       // refer to unique ops.
170       DenseSet<Identifier> key;
171       for (auto it : arr) {
172         if (!it.isa<SymbolRefAttr>())
173           return op->emitError(
174               "only SymbolRefAttr allowed in shape.lib attribute array");
175 
176         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
177             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
178         if (!shapeFnLib)
179           return op->emitError()
180                  << it << " does not refer to FunctionLibraryOp";
181         for (auto mapping : shapeFnLib.mapping()) {
182           if (!key.insert(mapping.first).second) {
183             return op->emitError("only one op to shape mapping allowed, found "
184                                  "multiple for `")
185                    << mapping.first << "`";
186           }
187         }
188       }
189       return success();
190     }
191 
192     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
193                          "allowed as shape.lib attribute");
194   }
195   return success();
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // AnyOp
200 //===----------------------------------------------------------------------===//
201 
202 // TODO: Canonicalization should be implemented for shapes that can be
203 // determined through mixtures of the known dimensions of the inputs.
204 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
205   // Only the last operand is checked because AnyOp is commutative.
206   if (operands.back())
207     return operands.back();
208 
209   return nullptr;
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // AssumingOp
214 //===----------------------------------------------------------------------===//
215 
216 static ParseResult parseAssumingOp(OpAsmParser &parser,
217                                    OperationState &result) {
218   result.regions.reserve(1);
219   Region *doRegion = result.addRegion();
220 
221   auto &builder = parser.getBuilder();
222   OpAsmParser::OperandType cond;
223   if (parser.parseOperand(cond) ||
224       parser.resolveOperand(cond, builder.getType<WitnessType>(),
225                             result.operands))
226     return failure();
227 
228   // Parse optional results type list.
229   if (parser.parseOptionalArrowTypeList(result.types))
230     return failure();
231 
232   // Parse the region and add a terminator if elided.
233   if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
234     return failure();
235   AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
236 
237   // Parse the optional attribute list.
238   if (parser.parseOptionalAttrDict(result.attributes))
239     return failure();
240   return success();
241 }
242 
243 static void print(OpAsmPrinter &p, AssumingOp op) {
244   bool yieldsResults = !op.results().empty();
245 
246   p << AssumingOp::getOperationName() << " " << op.witness();
247   if (yieldsResults) {
248     p << " -> (" << op.getResultTypes() << ")";
249   }
250   p.printRegion(op.doRegion(),
251                 /*printEntryBlockArgs=*/false,
252                 /*printBlockTerminators=*/yieldsResults);
253   p.printOptionalAttrDict(op.getAttrs());
254 }
255 
256 namespace {
257 // Removes AssumingOp with a passing witness and inlines the region.
258 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
259   using OpRewritePattern<AssumingOp>::OpRewritePattern;
260 
261   LogicalResult matchAndRewrite(AssumingOp op,
262                                 PatternRewriter &rewriter) const override {
263     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
264     if (!witness || !witness.passingAttr())
265       return failure();
266 
267     AssumingOp::inlineRegionIntoParent(op, rewriter);
268     return success();
269   }
270 };
271 } // namespace
272 
/// Adds the canonicalization patterns of AssumingOp to `patterns`; currently
/// this is only the AssumingWithTrue pattern.
void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                             MLIRContext *context) {
  // If taking a passing witness, inline region.
  patterns.insert<AssumingWithTrue>(context);
}
278 
279 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
280 void AssumingOp::getSuccessorRegions(
281     Optional<unsigned> index, ArrayRef<Attribute> operands,
282     SmallVectorImpl<RegionSuccessor> &regions) {
283   // AssumingOp has unconditional control flow into the region and back to the
284   // parent, so return the correct RegionSuccessor purely based on the index
285   // being None or 0.
286   if (index.hasValue()) {
287     regions.push_back(RegionSuccessor(getResults()));
288     return;
289   }
290 
291   regions.push_back(RegionSuccessor(&doRegion()));
292 }
293 
/// Replaces `op` by the contents of its region, spliced into the block that
/// holds the rewriter's current insertion point.
///
/// The uses of `op`'s results are replaced with the operands of the region's
/// terminator, and the terminator itself is erased.
/// NOTE(review): this relies on the rewriter's insertion point being located
/// at `op` when called — confirm against callers.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  // Split the current block at the insertion point so that the region's block
  // can be placed between the two halves.
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  // The terminator is looked up before the region is moved; its operands
  // become the replacement values for the op's results.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
313 
314 //===----------------------------------------------------------------------===//
315 // AssumingAllOp
316 //===----------------------------------------------------------------------===//
317 
/// Adds the canonicalization patterns of AssumingAllOp to `patterns`.
/// `AssumingAllOneOp` is not defined in this part of the file; presumably it
/// comes from the included ShapeCanonicalization.inc.
void AssumingAllOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<AssumingAllOneOp>(context);
}
322 
/// Folds the conjunction of witnesses.
///
/// Walks the operands from last to first. Each constant operand is erased
/// from the operation; a constant `false` makes the whole op fold to that
/// attribute, and if every operand is a constant `true` the op folds to
/// `true`. Reaching a non-constant operand stops the walk with a null result.
/// NOTE(review): constant operands erased before a non-constant operand is
/// reached stay erased even though a null result is returned — confirm that
/// this in-place mutation is intended.
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(true, getContext());
}
343 
344 static LogicalResult verify(AssumingAllOp op) {
345   // Ensure that AssumingAllOp contains at least one operand
346   if (op.getNumOperands() == 0)
347     return op.emitOpError("no operands specified");
348 
349   return success();
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // BroadcastOp
354 //===----------------------------------------------------------------------===//
355 
356 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
357   if (!operands[1])
358     return nullptr;
359 
360   auto rhsShape = llvm::to_vector<6>(
361       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
362   if (rhsShape.empty())
363     return lhs();
364 
365   if (!operands[0])
366     return nullptr;
367 
368   auto lhsShape = llvm::to_vector<6>(
369       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
370   if (lhsShape.empty())
371     return rhs();
372 
373   SmallVector<int64_t, 6> resultShape;
374   // If the shapes are not compatible, we can't fold it.
375   // TODO: Fold to an "error".
376   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
377     return nullptr;
378   Builder builder(getContext());
379   return builder.getIndexTensorAttr(resultShape);
380 }
381 
382 //===----------------------------------------------------------------------===//
383 // ConcatOp
384 //===----------------------------------------------------------------------===//
385 
386 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
387   if (!operands[0] || !operands[1])
388     return nullptr;
389   auto lhsShape = llvm::to_vector<6>(
390       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
391   auto rhsShape = llvm::to_vector<6>(
392       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
393   SmallVector<int64_t, 6> resultShape;
394   resultShape.append(lhsShape.begin(), lhsShape.end());
395   resultShape.append(rhsShape.begin(), rhsShape.end());
396   Builder builder(getContext());
397   return builder.getIndexTensorAttr(resultShape);
398 }
399 
400 //===----------------------------------------------------------------------===//
401 // ConstShapeOp
402 //===----------------------------------------------------------------------===//
403 
404 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
405   p << "shape.const_shape ";
406   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
407   p << "[";
408   interleaveComma(op.shape().getValues<int64_t>(), p,
409                   [&](int64_t i) { p << i; });
410   p << "] : ";
411   p.printType(op.getType());
412 }
413 
414 static ParseResult parseConstShapeOp(OpAsmParser &parser,
415                                      OperationState &result) {
416   if (parser.parseOptionalAttrDict(result.attributes))
417     return failure();
418   // We piggy-back on ArrayAttr parsing, though we don't internally store the
419   // shape as an ArrayAttr.
420   // TODO: Implement custom parser and maybe make syntax a bit more concise.
421   Attribute extentsRaw;
422   NamedAttrList dummy;
423   if (parser.parseAttribute(extentsRaw, "dummy", dummy))
424     return failure();
425   auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
426   if (!extentsArray)
427     return failure();
428   SmallVector<int64_t, 6> ints;
429   for (Attribute extent : extentsArray) {
430     IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
431     if (!attr)
432       return failure();
433     ints.push_back(attr.getInt());
434   }
435   Builder &builder = parser.getBuilder();
436   result.addAttribute("shape", builder.getIndexTensorAttr(ints));
437   Type resultTy;
438   if (parser.parseColonType(resultTy))
439     return failure();
440   result.types.push_back(resultTy);
441   return success();
442 }
443 
/// Folds to the op's own `shape` attribute; the operand attributes are unused.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
445 
/// Adds the canonicalization patterns of ConstShapeOp to `patterns`.
/// `TensorCastConstShape` is not defined in this part of the file; presumably
/// it comes from the included ShapeCanonicalization.inc.
void ConstShapeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<TensorCastConstShape>(context);
}
450 
451 //===----------------------------------------------------------------------===//
452 // CstrBroadcastableOp
453 //===----------------------------------------------------------------------===//
454 
455 namespace {
456 // Given an input shape Value, try to obtain the shape's values.
457 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
458   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
459     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
460     if (!type.hasRank())
461       return failure();
462     shapeValues = llvm::to_vector<6>(type.getShape());
463     return success();
464   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
465     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
466     return success();
467   } else {
468     return failure();
469   }
470 }
471 } // namespace
472 
/// Adds the canonicalization patterns of CstrBroadcastableOp to `patterns`.
/// `CstrBroadcastableEqOps` is not defined in this part of the file;
/// presumably it comes from the included ShapeCanonicalization.inc.
void CstrBroadcastableOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.insert<CstrBroadcastableEqOps>(context);
}
480 
481 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
482   // Both operands are not needed if one is a scalar.
483   if (operands[0] &&
484       operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
485     return BoolAttr::get(true, getContext());
486   if (operands[1] &&
487       operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
488     return BoolAttr::get(true, getContext());
489 
490   if (operands[0] && operands[1]) {
491     auto lhsShape = llvm::to_vector<6>(
492         operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
493     auto rhsShape = llvm::to_vector<6>(
494         operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
495     SmallVector<int64_t, 6> resultShape;
496     if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
497       return BoolAttr::get(true, getContext());
498   }
499 
500   // Lastly, see if folding can be completed based on what constraints are known
501   // on the input shapes.
502   SmallVector<int64_t, 6> lhsShape, rhsShape;
503   if (failed(getShapeVec(lhs(), lhsShape)))
504     return nullptr;
505   if (failed(getShapeVec(rhs(), rhsShape)))
506     return nullptr;
507 
508   if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
509     return BoolAttr::get(true, getContext());
510 
511   // Because a failing witness result here represents an eventual assertion
512   // failure, we do not replace it with a constant witness.
513   return nullptr;
514 }
515 
516 //===----------------------------------------------------------------------===//
517 // CstrEqOp
518 //===----------------------------------------------------------------------===//
519 
/// Adds the canonicalization patterns of CstrEqOp to `patterns`.
/// `CstrEqEqOps` is not defined in this part of the file; presumably it comes
/// from the included ShapeCanonicalization.inc.
void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.insert<CstrEqEqOps>(context);
}
525 
526 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
527   if (llvm::all_of(operands,
528                    [&](Attribute a) { return a && a == operands[0]; }))
529     return BoolAttr::get(true, getContext());
530 
531   // Because a failing witness result here represents an eventual assertion
532   // failure, we do not try to replace it with a constant witness. Similarly, we
533   // cannot if there are any non-const inputs.
534   return nullptr;
535 }
536 
537 //===----------------------------------------------------------------------===//
538 // ConstSizeOp
539 //===----------------------------------------------------------------------===//
540 
541 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
542                         int64_t value) {
543   build(builder, result, builder.getIndexAttr(value));
544 }
545 
/// Folds to the op's own `value` attribute; the operand attributes are unused.
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
547 
548 void ConstSizeOp::getAsmResultNames(
549     llvm::function_ref<void(Value, StringRef)> setNameFn) {
550   SmallString<4> buffer;
551   llvm::raw_svector_ostream os(buffer);
552   os << "c" << value();
553   setNameFn(getResult(), os.str());
554 }
555 
556 //===----------------------------------------------------------------------===//
557 // ConstWitnessOp
558 //===----------------------------------------------------------------------===//
559 
/// Folds to the op's own `passing` attribute; the operand attributes are
/// unused.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
561 
562 //===----------------------------------------------------------------------===//
563 // CstrRequireOp
564 //===----------------------------------------------------------------------===//
565 
/// Returns the constant attribute of the first operand as the fold result.
/// When that operand is not constant the attribute is null, which mirrors the
/// "cannot fold" result used by the other fold hooks in this file.
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}
569 
570 //===----------------------------------------------------------------------===//
571 // ShapeEqOp
572 //===----------------------------------------------------------------------===//
573 
574 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
575   auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
576   if (lhs == nullptr)
577     return {};
578   auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
579   if (rhs == nullptr)
580     return {};
581   return BoolAttr::get(lhs == rhs, getContext());
582 }
583 
584 //===----------------------------------------------------------------------===//
585 // IndexToSizeOp
586 //===----------------------------------------------------------------------===//
587 
588 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
589   // Constant values of both types, `shape.size` and `index`, are represented as
590   // `IntegerAttr`s which makes constant folding simple.
591   if (Attribute arg = operands[0])
592     return arg;
593   return {};
594 }
595 
/// Adds the canonicalization patterns of IndexToSizeOp to `patterns`.
/// `SizeToIndexToSizeCanonicalization` is not defined in this part of the
/// file; presumably it comes from the included ShapeCanonicalization.inc.
void IndexToSizeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<SizeToIndexToSizeCanonicalization>(context);
}
600 
601 //===----------------------------------------------------------------------===//
602 // FromExtentsOp
603 //===----------------------------------------------------------------------===//
604 
605 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
606   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
607     return nullptr;
608   SmallVector<int64_t, 6> extents;
609   for (auto attr : operands)
610     extents.push_back(attr.cast<IntegerAttr>().getInt());
611   Builder builder(getContext());
612   return builder.getIndexTensorAttr(extents);
613 }
614 
615 //===----------------------------------------------------------------------===//
616 // FunctionLibraryOp
617 //===----------------------------------------------------------------------===//
618 
619 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
620                               StringRef name) {
621   ensureTerminator(*result.addRegion(), builder, result.location);
622   result.attributes.push_back(builder.getNamedAttr(
623       ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
624 }
625 
626 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
627   auto attr = mapping()
628                   .get(op->getName().getIdentifier())
629                   .dyn_cast_or_null<FlatSymbolRefAttr>();
630   if (!attr)
631     return nullptr;
632   return lookupSymbol<FuncOp>(attr);
633 }
634 
635 ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
636                                    OperationState &result) {
637   // Parse the op name.
638   StringAttr nameAttr;
639   if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
640                              result.attributes))
641     return failure();
642 
643   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
644     return failure();
645 
646   auto *bodyRegion = result.addRegion();
647   if (parser.parseRegion(*bodyRegion))
648     return failure();
649 
650   FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
651                                       result.location);
652   if (parser.parseKeyword("mapping"))
653     return failure();
654 
655   DictionaryAttr mappingAttr;
656   if (parser.parseAttribute(mappingAttr,
657                             parser.getBuilder().getType<NoneType>(), "mapping",
658                             result.attributes))
659     return failure();
660   return success();
661 }
662 
663 void print(OpAsmPrinter &p, FunctionLibraryOp op) {
664   p << op.getOperationName() << ' ';
665   p.printSymbolName(op.getName());
666   p.printOptionalAttrDictWithKeyword(
667       op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
668   p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
669                 /*printBlockTerminators=*/false);
670   p << " mapping ";
671   p.printAttributeWithoutType(op.mappingAttr());
672 }
673 
674 //===----------------------------------------------------------------------===//
675 // GetExtentOp
676 //===----------------------------------------------------------------------===//
677 
678 Optional<int64_t> GetExtentOp::getConstantDim() {
679   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
680     return constSizeOp.value().getLimitedValue();
681   if (auto constantOp = dim().getDefiningOp<ConstantOp>())
682     return constantOp.value().cast<IntegerAttr>().getInt();
683   return llvm::None;
684 }
685 
686 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
687   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
688   if (!elements)
689     return nullptr;
690   Optional<int64_t> dim = getConstantDim();
691   if (!dim.hasValue())
692     return nullptr;
693   if (dim.getValue() >= elements.getNumElements())
694     return nullptr;
695   return elements.getValue({(uint64_t)dim.getValue()});
696 }
697 
698 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
699                         int64_t dim) {
700   auto loc = result.location;
701   auto dimAttr = builder.getIndexAttr(dim);
702   if (shape.getType().isa<ShapeType>()) {
703     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
704     build(builder, result, builder.getType<SizeType>(), shape, dim);
705   } else {
706     Value dim =
707         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
708     build(builder, result, builder.getIndexType(), shape, dim);
709   }
710 }
711 
712 //===----------------------------------------------------------------------===//
713 // RankOp
714 //===----------------------------------------------------------------------===//
715 
716 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
717   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
718   if (!shape)
719     return {};
720   int64_t rank = shape.getNumElements();
721   Builder builder(getContext());
722   return builder.getIndexAttr(rank);
723 }
724 
725 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
726 /// Constant folding fails in cases where only the rank is constant, not the
727 /// shape itself.
728 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
729 ///
730 /// Example:
731 ///
732 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
733 /// %rank = shape.rank %shape
734 ///
735 /// becomes
736 ///
737 /// %rank = shape.const_size 3
738 
739 namespace {
740 struct RankShapeOfCanonicalizationPattern
741     : public OpRewritePattern<shape::RankOp> {
742   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
743 
744   LogicalResult matchAndRewrite(shape::RankOp op,
745                                 PatternRewriter &rewriter) const override {
746     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
747     if (!shapeOfOp)
748       return failure();
749     auto rankedTensorType =
750         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
751     if (!rankedTensorType)
752       return failure();
753     int64_t rank = rankedTensorType.getRank();
754     if (op.getType().isa<IndexType>()) {
755       rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
756     } else if (op.getType().isa<shape::SizeType>()) {
757       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
758     } else {
759       return failure();
760     }
761     return success();
762   }
763 };
764 } // namespace
765 
/// Adds the canonicalization patterns of RankOp to `patterns`; currently this
/// is only RankShapeOfCanonicalizationPattern.
void shape::RankOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<RankShapeOfCanonicalizationPattern>(context);
}
770 
771 //===----------------------------------------------------------------------===//
772 // NumElementsOp
773 //===----------------------------------------------------------------------===//
774 
775 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
776 
777   // Fold only when argument constant.
778   Attribute shape = operands[0];
779   if (!shape)
780     return {};
781 
782   APInt product(64, 1);
783   for (auto value : shape.cast<DenseIntElementsAttr>())
784     product *= value;
785   Builder builder(getContext());
786   return builder.getIndexAttr(product.getLimitedValue());
787 }
788 
789 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
790                           Value shape) {
791   if (shape.getType().isa<ShapedType>()) {
792     auto type = builder.getIndexType();
793     return build(builder, result, type, shape);
794   }
795   auto type = SizeType::get(builder.getContext());
796   return build(builder, result, type, shape);
797 }
798 
799 //===----------------------------------------------------------------------===//
800 // MulOp
801 //===----------------------------------------------------------------------===//
802 
803 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
804   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
805   if (!lhs)
806     return nullptr;
807   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
808   if (!rhs)
809     return nullptr;
810   APInt folded = lhs.getValue() * rhs.getValue();
811   Type indexTy = IndexType::get(getContext());
812   return IntegerAttr::get(indexTy, folded);
813 }
814 
815 //===----------------------------------------------------------------------===//
816 // ShapeOfOp
817 //===----------------------------------------------------------------------===//
818 
819 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
820   auto type = getOperand().getType().dyn_cast<ShapedType>();
821   if (!type || !type.hasStaticShape())
822     return nullptr;
823   Builder builder(getContext());
824   return builder.getIndexTensorAttr(type.getShape());
825 }
826 
827 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
828   Type type = arg.getType().isa<ShapedType>()
829                   ? (Type)getExtentTensorType(builder.getContext())
830                   : (Type)builder.getType<ShapeType>();
831   return ShapeOfOp::build(builder, result, type, arg);
832 }
833 
834 namespace {
835 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
836   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
837 
838   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
839                                 PatternRewriter &rewriter) const override {
840     if (!op.arg().getType().isa<ShapedType>())
841       return failure();
842     if (op.getType().isa<ShapedType>())
843       return failure();
844 
845     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
846     return success();
847   }
848 };
849 } // namespace
850 
/// Adds the canonicalization patterns of `shape.shape_of` to `patterns`;
/// currently this is only `ShapeOfWithTensor`.
void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                            MLIRContext *context) {
  patterns.insert<ShapeOfWithTensor>(context);
}
855 
856 //===----------------------------------------------------------------------===//
857 // SizeToIndexOp
858 //===----------------------------------------------------------------------===//
859 
860 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
861   // Constant values of both types, `shape.size` and `index`, are represented as
862   // `IntegerAttr`s which makes constant folding simple.
863   if (Attribute arg = operands[0])
864     return arg;
865   return impl::foldCastOp(*this);
866 }
867 
/// Adds the canonicalization patterns of `shape.size_to_index` to `patterns`;
/// currently this is only `IndexToSizeToIndexCanonicalization`.
// NOTE(review): the pattern is not defined in this part of the file;
// presumably it comes from the included ShapeCanonicalization.inc - confirm.
void SizeToIndexOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<IndexToSizeToIndexCanonicalization>(context);
}
872 
873 //===----------------------------------------------------------------------===//
874 // YieldOp
875 //===----------------------------------------------------------------------===//
876 
877 static LogicalResult verify(shape::YieldOp op) {
878   auto *parentOp = op->getParentOp();
879   auto results = parentOp->getResults();
880   auto operands = op.getOperands();
881 
882   if (parentOp->getNumResults() != op.getNumOperands())
883     return op.emitOpError() << "number of operands does not match number of "
884                                "results of its parent";
885   for (auto e : llvm::zip(results, operands))
886     if (std::get<0>(e).getType() != std::get<1>(e).getType())
887       return op.emitOpError()
888              << "types mismatch between yield op and its parent";
889 
890   return success();
891 }
892 
893 //===----------------------------------------------------------------------===//
894 // SplitAtOp
895 //===----------------------------------------------------------------------===//
896 
897 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
898                               SmallVectorImpl<OpFoldResult> &results) {
899   if (!operands[0] || !operands[1])
900     return failure();
901   auto shapeVec = llvm::to_vector<6>(
902       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
903   auto shape = llvm::makeArrayRef(shapeVec);
904   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
905   // Verify that the split point is in the correct range.
906   // TODO: Constant fold to an "error".
907   int64_t rank = shape.size();
908   if (!(-rank <= splitPoint && splitPoint <= rank))
909     return failure();
910   if (splitPoint < 0)
911     splitPoint += shape.size();
912   Builder builder(operands[0].getContext());
913   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
914   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
915   return success();
916 }
917 
918 //===----------------------------------------------------------------------===//
919 // ToExtentTensorOp
920 //===----------------------------------------------------------------------===//
921 
922 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
923   if (!operands[0])
924     return impl::foldCastOp(*this);
925   Builder builder(getContext());
926   auto shape = llvm::to_vector<6>(
927       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
928   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
929                                     builder.getIndexType());
930   return DenseIntElementsAttr::get(type, shape);
931 }
932 
933 //===----------------------------------------------------------------------===//
934 // ReduceOp
935 //===----------------------------------------------------------------------===//
936 
937 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
938                      ValueRange initVals) {
939   result.addOperands(shape);
940   result.addOperands(initVals);
941 
942   Region *bodyRegion = result.addRegion();
943   bodyRegion->push_back(new Block);
944   Block &bodyBlock = bodyRegion->front();
945   bodyBlock.addArgument(builder.getIndexType());
946 
947   Type elementType;
948   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
949     elementType = tensorType.getElementType();
950   else
951     elementType = SizeType::get(builder.getContext());
952   bodyBlock.addArgument(elementType);
953 
954   for (Type initValType : initVals.getTypes()) {
955     bodyBlock.addArgument(initValType);
956     result.addTypes(initValType);
957   }
958 }
959 
960 static LogicalResult verify(ReduceOp op) {
961   // Verify block arg types.
962   Block &block = op.region().front();
963 
964   // The block takes index, extent, and aggregated values as arguments.
965   auto blockArgsCount = op.initVals().size() + 2;
966   if (block.getNumArguments() != blockArgsCount)
967     return op.emitOpError() << "ReduceOp body is expected to have "
968                             << blockArgsCount << " arguments";
969 
970   // The first block argument is the index and must always be of type `index`.
971   if (!block.getArgument(0).getType().isa<IndexType>())
972     return op.emitOpError(
973         "argument 0 of ReduceOp body is expected to be of IndexType");
974 
975   // The second block argument is the extent and must be of type `size` or
976   // `index`, depending on whether the reduce operation is applied to a shape or
977   // to an extent tensor.
978   Type extentTy = block.getArgument(1).getType();
979   if (op.shape().getType().isa<ShapeType>()) {
980     if (!extentTy.isa<SizeType>())
981       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
982                             "SizeType if the ReduceOp operates on a ShapeType");
983   } else {
984     if (!extentTy.isa<IndexType>())
985       return op.emitOpError(
986           "argument 1 of ReduceOp body is expected to be of IndexType if the "
987           "ReduceOp operates on an extent tensor");
988   }
989 
990   for (auto type : llvm::enumerate(op.initVals()))
991     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
992       return op.emitOpError()
993              << "type mismatch between argument " << type.index() + 2
994              << " of ReduceOp body and initial value " << type.index();
995   return success();
996 }
997 
998 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
999   // Parse operands.
1000   SmallVector<OpAsmParser::OperandType, 3> operands;
1001   Type shapeOrExtentTensorType;
1002   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1003                               OpAsmParser::Delimiter::Paren) ||
1004       parser.parseColonType(shapeOrExtentTensorType) ||
1005       parser.parseOptionalArrowTypeList(result.types))
1006     return failure();
1007 
1008   // Resolve operands.
1009   auto initVals = llvm::makeArrayRef(operands).drop_front();
1010   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1011                             result.operands) ||
1012       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1013                              result.operands))
1014     return failure();
1015 
1016   // Parse the body.
1017   Region *body = result.addRegion();
1018   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
1019     return failure();
1020 
1021   // Parse attributes.
1022   if (parser.parseOptionalAttrDict(result.attributes))
1023     return failure();
1024 
1025   return success();
1026 }
1027 
1028 static void print(OpAsmPrinter &p, ReduceOp op) {
1029   p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
1030     << ") : " << op.shape().getType();
1031   p.printOptionalArrowTypeList(op.getResultTypes());
1032   p.printRegion(op.region());
1033   p.printOptionalAttrDict(op.getAttrs());
1034 }
1035 
1036 #define GET_OP_CLASSES
1037 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1038