//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::shape;

#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"

namespace {
#include "ShapeCanonicalization.inc"
} // namespace

RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
  return RankedTensorType::get({rank}, IndexType::get(ctx));
}

bool shape::isExtentTensorType(Type type) {
  auto ranked = type.dyn_cast<RankedTensorType>();
  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
}

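// `getShapeVec` extracts static extents from `input` on a best-effort basis.
// Illustrative IR for the two supported cases (value names are hypothetical):
//
//   %0 = shape.shape_of %t : tensor<1x2x3xf32> -> tensor<3xindex>  // [1, 2, 3]
//   %1 = shape.const_shape [4, 5] : tensor<2xindex>                // [4, 5]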
LogicalResult shape::getShapeVec(Value input,
                                 SmallVectorImpl<int64_t> &shapeValues) {
  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
    auto type = inputOp.getArg().getType().cast<ShapedType>();
    if (!type.hasRank())
      return failure();
    llvm::append_range(shapeValues, type.getShape());
    return success();
  }
  DenseIntElementsAttr attr;
  if (matchPattern(input, m_Constant(&attr))) {
    llvm::append_range(shapeValues, attr.getValues<int64_t>());
    return success();
  }
  return failure();
}

static bool isErrorPropagationPossible(TypeRange operandTypes) {
  return llvm::any_of(operandTypes, [](Type ty) {
    return ty.isa<SizeType, ShapeType, ValueShapeType>();
  });
}

static LogicalResult verifySizeOrIndexOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!resultTy.isa<SizeType>())
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `size` to propagate them";
  }
  return success();
}

static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!resultTy.isa<ShapeType>())
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `shape` to propagate them";
  }
  return success();
}

template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
  return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
}

template <typename... Ty, typename... ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
}

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for inlining shape dialect ops.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace

void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
      >();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect
  // is still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining them.
  allowUnknownOperations();
}

Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  if (type.isa<ShapeType>() || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, type, value);
  return nullptr;
}

LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  // Verify shape.lib attribute.
  if (attribute.getName() == "shape.lib") {
    if (!op->hasTrait<OpTrait::SymbolTable>())
      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");

    if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
      auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
      if (!symbol)
        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)
                 ? success()
                 : op->emitError()
                       << symbolRef << " required to be shape function library";
    }

    if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
      // Verify all entries are function libraries and mappings in libraries
      // refer to unique ops.
      DenseSet<StringAttr> key;
      for (auto it : arr) {
        if (!it.isa<SymbolRefAttr>())
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");

        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
            SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
        if (!shapeFnLib)
          return op->emitError()
                 << it << " does not refer to FunctionLibraryOp";
        for (auto mapping : shapeFnLib.getMapping()) {
          if (!key.insert(mapping.getName()).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.getName() << "`";
          }
        }
      }
      return success();
    }

    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//

// TODO: Canonicalization should be implemented for shapes that can be
// determined through mixtures of the known dimensions of the inputs.
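// For example (an illustrative sketch; value names are hypothetical):
//
//   %0 = shape.any %dynamic, %static
//
// folds to the attribute of %static when it is defined by a constant. Since
// AnyOp is commutative, constant operands are canonicalized to the end of the
// operand list, so checking only the last operand suffices.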
OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
  // Only the last operand is checked because AnyOp is commutative.
  if (operands.back())
    return operands.back();

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AssumingOp
//===----------------------------------------------------------------------===//

ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void AssumingOp::print(OpAsmPrinter &p) {
  bool yieldsResults = !getResults().empty();

  p << " " << getWitness();
  if (yieldsResults)
    p << " -> (" << getResultTypes() << ")";
  p << ' ';
  p.printRegion(getDoRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict((*this)->getAttrs());
}

namespace {
// Removes AssumingOp with a passing witness and inlines the region.
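// For example (illustrative):
//
//   %w = shape.const_witness true
//   %0 = shape.assuming %w -> (tensor<2xf32>) { ... }
//
// is replaced by the values yielded from the region, which is inlined into
// the parent block.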
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.getPassingAttr())
      return failure();

    AssumingOp::inlineRegionIntoParent(op, rewriter);
    return success();
  }
};

struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
} // namespace

void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}

// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // AssumingOp has unconditional control flow into the region and back to the
  // parent, so return the correct RegionSuccessor purely based on the index
  // being None or 0.
  if (index) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&getDoRegion()));
}

void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}

void AssumingOp::build(
    OpBuilder &builder, OperationState &result, Value witness,
    function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {

  result.addOperands(witness);
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();

  // Build body.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(&bodyBlock);
  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
  builder.create<AssumingYieldOp>(result.location, yieldValues);

  SmallVector<Type, 2> assumingTypes;
  for (Value v : yieldValues)
    assumingTypes.push_back(v.getType());
  result.addTypes(assumingTypes);
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::AddOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
  // add(x, 0) -> x
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();

  return constFoldBinaryOp<IntegerAttr>(
      operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
}

LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//

namespace {

// Merge multiple `shape.assuming_all` operations together.
//
//   %0 = shape.assuming_all %w0, %w1
//   %1 = shape.assuming_all %w2, %0
//
// to:
//
//   %0 = shape.assuming_all %w0, %w1, %w2
struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> operands;

    for (Value operand : op.getInputs()) {
      if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
        operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
      else
        operands.push_back(operand);
    }

    // We didn't find any other `assuming_all` ops to merge with.
    if (operands.size() == op.getNumOperands())
      return failure();

    // Replace with a new `assuming_all` operation with merged constraints.
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
    return success();
  }
};

// Eliminate `cstr_broadcastable` operands from an `assuming_all` operation
// when they are subsumed by others.
//
//   %0 = shape.cstr_broadcastable %shape0, %shape1
//   %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//
//   %2 = shape.cstr_broadcastable %shape3, %shape4
//   %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//
//   %4 = shape.assuming_all %0, %1, %2, %3
//
// to:
//
//   %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//   %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//   %2 = shape.assuming_all %0, %1
//
// In this example, if shapes [0, 1, 2] are broadcastable, then shapes [0, 1]
// are broadcastable too and can be removed from the list of constraints. If
// shapes [0, 1, 2] are not broadcastable, then it doesn't matter whether
// shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    // Collect all `CstrBroadcastableOp` operands first.
    SetVector<CstrBroadcastableOp> operands;
    for (Value operand : op.getInputs()) {
      // TODO: Apply this optimization if some of the witnesses are not
      // produced by the `cstr_broadcastable`.
      auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
      if (!broadcastable)
        return failure();

      operands.insert(broadcastable);
    }

    // Skip trivial `assuming_all` operations.
    if (operands.size() <= 1)
      return failure();

    // Collect shapes checked by `cstr_broadcastable` operands.
    SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
    for (auto cstr : operands) {
      DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
      shapes.emplace_back(cstr, std::move(shapesSet));
    }

    // Sort by the number of shape operands (larger to smaller).
    llvm::sort(shapes, [](auto a, auto b) {
      return a.first.getNumOperands() > b.first.getNumOperands();
    });

    // We start from the `cstr_broadcastable` operations with the largest
    // number of shape operands and remove redundant `cstr_broadcastable`
    // operations. We do this until we find a set of `cstr_broadcastable`
    // operations with non-overlapping constraints.
    SmallVector<CstrBroadcastableOp> markedForErase;

    for (unsigned i = 0; i < shapes.size(); ++i) {
      auto isSubset = [&](auto pair) {
        return llvm::set_is_subset(pair.second, shapes[i].second);
      };

      // Keep redundant `cstr_broadcastable` operations to be erased.
      auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
      for (auto *it0 = it; it0 < shapes.end(); ++it0)
        markedForErase.push_back(it0->first);
      shapes.erase(it, shapes.end());
    }

    // We didn't find any operands that could be removed.
    if (markedForErase.empty())
      return failure();

    // Collect non-overlapping `cstr_broadcastable` constraints.
    SmallVector<Value> uniqueConstraints;
    for (auto &shape : shapes)
      uniqueConstraints.push_back(shape.first.getResult());

    // Replace with a new `assuming_all` operation ...
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);

    // ... and maybe erase `cstr_broadcastable` ops without uses.
    for (auto &op : markedForErase)
      if (op->use_empty())
        rewriter.eraseOp(op);

    return success();
  }
};

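// Replace an `assuming_all` whose witnesses all stem from `cstr_eq` ops over
// overlapping sets of shapes with a single `cstr_eq`. Illustrative IR (types
// elided; value names are hypothetical):
//
//   %0 = shape.cstr_eq %a, %b
//   %1 = shape.cstr_eq %b, %c
//   %2 = shape.assuming_all %0, %1
//
// becomes
//
//   %2 = shape.cstr_eq %a, %b, %b, %c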
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};

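// Removes duplicate operands from a variadic op, e.g. (illustrative, types
// elided):
//
//   %0 = shape.assuming_all %w, %w, %v
//
// becomes
//
//   %0 = shape.assuming_all %w, %v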
template <typename OpTy>
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Find unique operands.
    SetVector<Value> unique(op.operand_begin(), op.operand_end());

    // Reduce op to equivalent with unique operands.
    if (unique.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
                                        unique.takeVector(), op->getAttrs());
      return success();
    }

    return failure();
  }
};
} // namespace

void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns
      .add<MergeAssumingAllOps, AssumingAllOneOp,
           AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
           RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}

OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this op is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant.
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false.
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}

LogicalResult AssumingAllOp::verify() {
  // Ensure that AssumingAllOp contains at least one operand.
  if (getNumOperands() == 0)
    return emitOpError("no operands specified");

  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (getShapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }

  // TODO: Support folding with more than 2 input shapes
  if (getShapes().size() > 2)
    return nullptr;

  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

LogicalResult BroadcastOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}

namespace {
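// Removes operands that are statically known to be empty shapes, since they
// cannot affect a broadcast. Illustrative IR (types elided where obvious):
//
//   %e = shape.const_shape [] : tensor<0xindex>
//   %0 = shape.broadcast %a, %e, %b
//
// becomes
//
//   %0 = shape.broadcast %a, %b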
template <typename OpTy>
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto isPotentiallyNonEmptyShape = [](Value shape) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        if (extentTensorTy.getDimSize(0) == 0)
          return false;
      }
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        if (constShape.getShape().empty())
          return false;
      }
      return true;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));

    // Reduce op to equivalent without empty shape operands.
    if (newOperands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
                                        op->getAttrs());
      return success();
    }

    return failure();
  }
};

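// Forwards the operand of a single-operand broadcast, inserting a cast when
// the operand and result types differ. A sketch (the types shown are
// hypothetical):
//
//   %0 = shape.broadcast %a : tensor<3xindex> -> tensor<3xindex>
//
// is replaced by %a directly.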
struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.getShapes().front();

    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (op.getType().isa<ShapeType>()) {
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        assert(!op.getType().isa<ShapeType>() &&
               !replacement.getType().isa<ShapeType>() &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
      }
    }

    rewriter.replaceOp(op, replacement);
    return success();
  }
};

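// Folds all constant-shape operands of a broadcast into a single constant
// operand. Illustrative IR (value names are hypothetical):
//
//   %0 = shape.broadcast %arg, %c0, %c1
//
// where %c0 and %c1 are `shape.const_shape` ops, becomes
//
//   %0 = shape.broadcast %arg, %c
//
// with %c the constant broadcast of %c0 and %c1.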
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};

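// Looks through `tensor.cast` operands that only discard static extent
// information. Illustrative IR (value names are hypothetical, trailing types
// elided):
//
//   %c = tensor.cast %a : tensor<3xindex> to tensor<?xindex>
//   %0 = shape.broadcast %c, %b : ...
//
// is rewritten to use %a in place of %c.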
template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
    : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Canonicalize operands.
    bool anyChange = false;
    auto canonicalizeOperand = [&](Value operand) {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
        // Only eliminate the cast if it holds no shape information.
        bool isInformationLosingCast =
            castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
        if (isInformationLosingCast) {
          anyChange = true;
          return castOp.getSource();
        }
      }
      return operand;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::map_range(op.getOperands(), canonicalizeOperand));

    // Rewrite op if any change required.
    if (!anyChange)
      return failure();
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
    return success();
  }
};

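// Concretizes a dynamic extent-tensor result type when the maximal operand
// rank is statically known, then casts back to the original result type.
// A sketch in illustrative IR:
//
//   %0 = shape.broadcast %a, %b : tensor<2xindex>, tensor<3xindex>
//            -> tensor<?xindex>
//
// becomes a broadcast with result type tensor<3xindex>, followed by a
// tensor.cast back to tensor<?xindex>.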
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;
  resultShape.append(lhsShape.begin(), lhsShape.end());
  resultShape.append(rhsShape.begin(), rhsShape.end());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//

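// Custom assembly format: the extents are printed as a bracketed list followed
// by the result type, e.g. (illustrative):
//
//   %0 = shape.const_shape [1, 2, 3] : !shape.shape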
void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}

ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}

OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }

void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}

LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Builder b(context);
  auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
  if (!shape)
    return emitOptionalError(location, "missing shape attribute");
  inferredReturnTypes.assign({RankedTensorType::get(
      {static_cast<int64_t>(shape.size())}, b.getIndexType())});
  return success();
}

bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
                                                        TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;

  Type lhs = l.front();
  Type rhs = r.front();

  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
    // Shape type is compatible with all other valid return types.
    return true;
  return lhs == rhs;
}

//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//

void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // These canonicalization patterns overlap with the folding logic below; they
  // apply when additional shape information is inferred at some point but does
  // not lead to a complete fold.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}

// Return true if at most one of the given attributes represents a non-scalar
// broadcast (an unknown attribute is conservatively treated as non-scalar).
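// For example, attributes for shapes {[], [], [2, 3]} satisfy this (one
// non-scalar), while {[2], [3]} do not (two non-scalars).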
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}

OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are
  // known on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}

LogicalResult CstrBroadcastableOp::verify() {
  // Ensure that CstrBroadcastableOp contains at least two operands.
  if (getNumOperands() < 2)
    return emitOpError("requires at least 2 input shapes");
  return success();
}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.add<CstrEqEqOps>(context);
}

OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
  if (llvm::all_of(operands,
                   [&](Attribute a) { return a && a == operands[0]; }))
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not try to replace it with a constant witness. Similarly, we
  // cannot if there are any non-const inputs.
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//

void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}

OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }

void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << getValue();
  setNameFn(getResult(), os.str());
}

//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
  return getPassingAttr();
}

//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//

OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;

  // Division in APInt does not follow floor(lhs, rhs) when the result is
  // negative. Rather, APInt rounds toward zero.
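  // For example, sdivrem(-7, 2) yields quotient -3 and remainder -1; floor
  // division expects -4, so the code below subtracts 1 from the quotient when
  // it is negative and the remainder is nonzero.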
  APInt quotient, remainder;
  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
  if (quotient.isNegative() && !remainder.isNullValue()) {
    quotient -= 1;
  }

  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, quotient);
}

LogicalResult mlir::shape::DivOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// ShapeEqOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
  bool allSame = true;
  if (!operands.empty() && !operands[0])
    return {};
  for (Attribute operand : operands.drop_front(1)) {
    if (!operand)
      return {};
    allSame = allSame && operand == operands[0];
  }
  return BoolAttr::get(getContext(), allSame);
}

//===----------------------------------------------------------------------===//
// IndexToSizeOp
//===----------------------------------------------------------------------===//

OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented as
  // `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  return {};
}

void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
// FromExtentsOp
//===----------------------------------------------------------------------===//

OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
  if (llvm::any_of(operands, [](Attribute a) { return !a; }))
    return nullptr;
  SmallVector<int64_t, 6> extents;
  for (auto attr : operands)
    extents.push_back(attr.cast<IntegerAttr>().getInt());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(extents);
}

//===----------------------------------------------------------------------===//
// FunctionLibraryOp
//===----------------------------------------------------------------------===//

void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = getMapping()
                  .get(op->getName().getIdentifier())
                  .dyn_cast_or_null<FlatSymbolRefAttr>();
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}

ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}

void FunctionLibraryOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());
  p.printOptionalAttrDictWithKeyword(
      (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(getMappingAttr());
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto buildFuncType =
      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
         function_interface_impl::VariadicFlag,
         std::string &) { return builder.getFunctionType(argTypes, results); };

  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, buildFuncType);
}

void FuncOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}

//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//

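// `getConstantDim` below returns the accessed dimension as a constant when the
// index operand is defined by, e.g. (illustrative), `shape.const_size 2` or an
// `arith.constant` of index type.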
Optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return llvm::None;
}

OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
  auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!elements)
    return nullptr;
  Optional<int64_t> dim = getConstantDim();
  if (!dim.has_value())
    return nullptr;
  if (dim.value() >= elements.getNumElements())
    return nullptr;
  return elements.getValues<Attribute>()[(uint64_t)dim.value()];
}

void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
                        int64_t dim) {
  auto loc = result.location;
  auto dimAttr = builder.getIndexAttr(dim);
  if (shape.getType().isa<ShapeType>()) {
    Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
    build(builder, result, builder.getType<SizeType>(), shape, dim);
  } else {
    Value dim =
        builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
    build(builder, result, builder.getIndexType(), shape, dim);
  }
}

LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1286 
1287 //===----------------------------------------------------------------------===//
1288 // IsBroadcastableOp
1289 //===----------------------------------------------------------------------===//
1290 
void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}

OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // Can always broadcast fewer than two shapes.
  if (operands.size() < 2) {
    return BoolAttr::get(getContext(), true);
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// MeetOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({operands[0].getType()});
  return success();
}

bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (lhs != rhs)
    return false;

  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
  auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!shape)
    return {};
  int64_t rank = shape.getNumElements();
  Builder builder(getContext());
  return builder.getIndexAttr(rank);
}

/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3

namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace

void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}

LogicalResult mlir::shape::RankOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//

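// Folds to the product of all extents when the shape is constant, e.g. a
// constant shape [2, 3, 5] folds to 30 (illustrative; the product is
// accumulated in a 64-bit APInt).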
OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
  // Fold only when the argument is constant.
  Attribute shape = operands[0];
  if (!shape)
    return {};

  APInt product(64, 1);
  for (auto value : shape.cast<DenseIntElementsAttr>())
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}

LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::NumElementsOp::verify() {
  return verifySizeOrIndexOp(*this);
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

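// When both operands have the same type, the result keeps that type; for
// mixed operands the result is conservatively inferred as `!shape.size`.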
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
    return true;
  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// MinOp
//===----------------------------------------------------------------------===//

OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

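// As for `max` above: identical operand types are propagated, while mixed
// operand types conservatively infer `!shape.size`.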
LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
    return true;
  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

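// Folds the product of two constant operands into a constant of type
// `index`, e.g. (illustrative) multiplying constant sizes 3 and 5 folds to
// an index attribute holding 15.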
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;
  APInt folded = lhs.getValue() * rhs.getValue();
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, folded);
}

LogicalResult mlir::shape::MulOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

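// Folds to a constant extent tensor when the operand type is fully static,
// e.g. (illustrative IR):
//
//   %0 = shape.shape_of %arg : tensor<2x3xf32> -> tensor<2xindex>
//
// folds to the index tensor attribute [2, 3].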
OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
  auto type = getOperand().getType().dyn_cast<ShapedType>();
  if (!type || !type.hasStaticShape())
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(type.getShape());
}

namespace {
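// Canonicalize `shape.shape_of` on a tensor argument to return the inferred
// extent tensor type rather than `!shape.shape` (illustrative IR):
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> !shape.shape
// ```
// to
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// ```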
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getArg().getType().isa<ShapedType>())
      return failure();
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
                                                  op.getArg());
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
    return success();
  }
};
} // namespace

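// `ExtractFromShapeOfExtentTensor` comes from the generated patterns in
// ShapeCanonicalization.inc (included above); it is expected to forward a
// `tensor.extract` of a `shape.shape_of` result to a `tensor.dim` on the
// original tensor.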
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
               ExtractFromShapeOfExtentTensor>(context);
}

LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ValueShapeType>())
    inferredReturnTypes.assign({ShapeType::get(context)});
  else {
    auto shapedTy = operands[0].getType().cast<ShapedType>();
    int64_t rank =
        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
    Type indexTy = IndexType::get(context);
    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
    inferredReturnTypes.assign({extentTensorTy});
  }
  return success();
}

bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
    return false;

  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
    // Shape type is compatible with all other valid return types.
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}

LogicalResult shape::ShapeOfOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}

//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//

OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented
  // as `IntegerAttr`s, which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  return OpFoldResult();
}

void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}

bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

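// Verifies that the yielded operands match the parent op's results one to
// one, both in count and in type.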
LogicalResult shape::YieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  auto results = parentOp->getResults();
  auto operands = getOperands();

  if (parentOp->getNumResults() != getNumOperands())
    return emitOpError() << "number of operands does not match number of "
                            "results of its parent";
  for (auto e : llvm::zip(results, operands))
    if (std::get<0>(e).getType() != std::get<1>(e).getType())
      return emitOpError() << "types mismatch between yield op and its parent";

  return success();
}

//===----------------------------------------------------------------------===//
// SplitAtOp
//===----------------------------------------------------------------------===//

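// Folds when both the shape and the split point are constant. The split
// point may be negative, in which case it counts from the back, e.g.
// (illustrative) splitting [4, 5, 6] at 1 yields [4] and [5, 6], while
// splitting at -1 yields [4, 5] and [6].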
LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<OpFoldResult> &results) {
  if (!operands[0] || !operands[1])
    return failure();
  auto shapeVec = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto shape = llvm::makeArrayRef(shapeVec);
  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
  // Verify that the split point is in the correct range.
  // TODO: Constant fold to an "error".
  int64_t rank = shape.size();
  if (-rank > splitPoint || splitPoint > rank)
    return failure();
  if (splitPoint < 0)
    splitPoint += shape.size();
  Builder builder(operands[0].getContext());
  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
  return success();
}

//===----------------------------------------------------------------------===//
// ToExtentTensorOp
//===----------------------------------------------------------------------===//

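// Folds a constant shape to the equivalent dense `index` tensor, e.g.
// (illustrative) a constant shape [1, 2] folds to a `tensor<2xindex>`
// attribute holding [1, 2].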
OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[0])
    return OpFoldResult();
  Builder builder(getContext());
  auto shape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
                                    builder.getIndexType());
  return DenseIntElementsAttr::get(type, shape);
}

bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
    if (!inputTensor.getElementType().isa<IndexType>() ||
        inputTensor.getRank() != 1)
      return false;
  } else if (!inputs[0].isa<ShapeType>()) {
    return false;
  }

  TensorType outputTensor = outputs[0].dyn_cast<TensorType>();
  return outputTensor && outputTensor.getElementType().isa<IndexType>();
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

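// The body block takes the iteration index, the current extent, and the
// accumulated values as arguments, e.g. (illustrative IR):
//
//   %sum = shape.reduce(%shape, %init) : tensor<3xindex> -> index {
//     ^bb0(%i : index, %extent : index, %acc : index):
//       %next = arith.addi %acc, %extent : index
//       shape.yield %next : index
//   }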
void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  result.addOperands(shape);
  result.addOperands(initVals);

  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  bodyBlock.addArgument(builder.getIndexType(), result.location);

  Type elementType;
  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
    elementType = tensorType.getElementType();
  else
    elementType = SizeType::get(builder.getContext());
  bodyBlock.addArgument(elementType, shape.getLoc());

  for (Value initVal : initVals) {
    bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
    result.addTypes(initVal.getType());
  }
}

LogicalResult ReduceOp::verify() {
  // Verify block arg types.
  Block &block = getRegion().front();

  // The block takes index, extent, and aggregated values as arguments.
  auto blockArgsCount = getInitVals().size() + 2;
  if (block.getNumArguments() != blockArgsCount)
    return emitOpError() << "ReduceOp body is expected to have "
                         << blockArgsCount << " arguments";

  // The first block argument is the index and must always be of type `index`.
  if (!block.getArgument(0).getType().isa<IndexType>())
    return emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  // The second block argument is the extent and must be of type `size` or
  // `index`, depending on whether the reduce operation is applied to a shape
  // or to an extent tensor.
  Type extentTy = block.getArgument(1).getType();
  if (getShape().getType().isa<ShapeType>()) {
    if (!extentTy.isa<SizeType>())
      return emitOpError("argument 1 of ReduceOp body is expected to be of "
                         "SizeType if the ReduceOp operates on a ShapeType");
  } else {
    if (!extentTy.isa<IndexType>())
      return emitOpError(
          "argument 1 of ReduceOp body is expected to be of IndexType if the "
          "ReduceOp operates on an extent tensor");
  }

  for (const auto &type : llvm::enumerate(getInitVals()))
    if (block.getArgument(type.index() + 2).getType() != type.value().getType())
      return emitOpError() << "type mismatch between argument "
                           << type.index() + 2
                           << " of ReduceOp body and initial value "
                           << type.index();
  return success();
}

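// The custom assembly format, mirrored by the printer below, is roughly:
//   shape.reduce(%shape, %init...) : shape-or-extent-tensor-type
//       -> result-types { region } attr-dict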
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

void ReduceOp::print(OpAsmPrinter &p) {
  p << '(' << getShape() << ", " << getInitVals()
    << ") : " << getShape().getType();
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"