1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
#include <algorithm>
#include <utility>

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
29
30 using namespace mlir;
31 using namespace mlir::shape;
32
33 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
34
namespace {
// Generated rewrite patterns from ShapeCanonicalization.inc, kept file-local by
// the anonymous namespace. NOTE(review): presumably this is where patterns such
// as AssumingAllOneOp (registered below but not defined in this file) come
// from — confirm against the .td source.
#include "ShapeCanonicalization.inc"
} // namespace
38
getExtentTensorType(MLIRContext * ctx,int64_t rank)39 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
40 return RankedTensorType::get({rank}, IndexType::get(ctx));
41 }
42
isExtentTensorType(Type type)43 bool shape::isExtentTensorType(Type type) {
44 auto ranked = type.dyn_cast<RankedTensorType>();
45 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
46 }
47
getShapeVec(Value input,SmallVectorImpl<int64_t> & shapeValues)48 LogicalResult shape::getShapeVec(Value input,
49 SmallVectorImpl<int64_t> &shapeValues) {
50 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
51 auto type = inputOp.getArg().getType().cast<ShapedType>();
52 if (!type.hasRank())
53 return failure();
54 llvm::append_range(shapeValues, type.getShape());
55 return success();
56 }
57 DenseIntElementsAttr attr;
58 if (matchPattern(input, m_Constant(&attr))) {
59 llvm::append_range(shapeValues, attr.getValues<int64_t>());
60 return success();
61 }
62 return failure();
63 }
64
isErrorPropagationPossible(TypeRange operandTypes)65 static bool isErrorPropagationPossible(TypeRange operandTypes) {
66 return llvm::any_of(operandTypes, [](Type ty) {
67 return ty.isa<SizeType, ShapeType, ValueShapeType>();
68 });
69 }
70
verifySizeOrIndexOp(Operation * op)71 static LogicalResult verifySizeOrIndexOp(Operation *op) {
72 assert(op != nullptr && op->getNumResults() == 1);
73 Type resultTy = op->getResultTypes().front();
74 if (isErrorPropagationPossible(op->getOperandTypes())) {
75 if (!resultTy.isa<SizeType>())
76 return op->emitOpError()
77 << "if at least one of the operands can hold error values then "
78 "the result must be of type `size` to propagate them";
79 }
80 return success();
81 }
82
verifyShapeOrExtentTensorOp(Operation * op)83 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
84 assert(op != nullptr && op->getNumResults() == 1);
85 Type resultTy = op->getResultTypes().front();
86 if (isErrorPropagationPossible(op->getOperandTypes())) {
87 if (!resultTy.isa<ShapeType>())
88 return op->emitOpError()
89 << "if at least one of the operands can hold error values then "
90 "the result must be of type `shape` to propagate them";
91 }
92 return success();
93 }
94
95 template <typename... Ty>
eachHasOnlyOneOfTypes(TypeRange typeRange)96 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
97 return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
98 }
99
100 template <typename... Ty, typename... ranges>
eachHasOnlyOneOfTypes(TypeRange l,ranges...rs)101 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
102 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
103 }
104
105 //===----------------------------------------------------------------------===//
106 // InlinerInterface
107 //===----------------------------------------------------------------------===//
108
109 namespace {
110 /// This class defines the interface for inlining shape dialect ops.
111 struct ShapeInlinerInterface : public DialectInlinerInterface {
112 using DialectInlinerInterface::DialectInlinerInterface;
113
114 // Returns true if the given region 'src' can be inlined into the region
115 // 'dest' that is attached to an operation registered to the current dialect.
isLegalToInline__anonde1778d90311::ShapeInlinerInterface116 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
117 BlockAndValueMapping &) const final {
118 return true;
119 }
120
121 // Returns true if the given operation 'op', that is registered to this
122 // dialect, can be inlined into the region 'dest' that is attached to an
123 // operation registered to the current dialect.
isLegalToInline__anonde1778d90311::ShapeInlinerInterface124 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
125 BlockAndValueMapping &) const final {
126 return true;
127 }
128 };
129 } // namespace
130
/// Registers the dialect's operations and types (lists expanded from the
/// TableGen-generated .inc files), its inliner interface, and allows
/// unregistered operations.
void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
      >();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
146
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)147 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
148 Attribute value, Type type,
149 Location loc) {
150 if (type.isa<ShapeType>() || isExtentTensorType(type))
151 return builder.create<ConstShapeOp>(loc, type,
152 value.cast<DenseIntElementsAttr>());
153 if (type.isa<SizeType>())
154 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
155 if (type.isa<WitnessType>())
156 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
157 if (arith::ConstantOp::isBuildableWith(value, type))
158 return builder.create<arith::ConstantOp>(loc, type, value);
159 return nullptr;
160 }
161
verifyOperationAttribute(Operation * op,NamedAttribute attribute)162 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
163 NamedAttribute attribute) {
164 // Verify shape.lib attribute.
165 if (attribute.getName() == "shape.lib") {
166 if (!op->hasTrait<OpTrait::SymbolTable>())
167 return op->emitError(
168 "shape.lib attribute may only be on op implementing SymbolTable");
169
170 if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
171 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
172 if (!symbol)
173 return op->emitError("shape function library ")
174 << symbolRef << " not found";
175 return isa<shape::FunctionLibraryOp>(symbol)
176 ? success()
177 : op->emitError()
178 << symbolRef << " required to be shape function library";
179 }
180
181 if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
182 // Verify all entries are function libraries and mappings in libraries
183 // refer to unique ops.
184 DenseSet<StringAttr> key;
185 for (auto it : arr) {
186 if (!it.isa<SymbolRefAttr>())
187 return op->emitError(
188 "only SymbolRefAttr allowed in shape.lib attribute array");
189
190 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
191 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
192 if (!shapeFnLib)
193 return op->emitError()
194 << it << " does not refer to FunctionLibraryOp";
195 for (auto mapping : shapeFnLib.getMapping()) {
196 if (!key.insert(mapping.getName()).second) {
197 return op->emitError("only one op to shape mapping allowed, found "
198 "multiple for `")
199 << mapping.getName() << "`";
200 }
201 }
202 }
203 return success();
204 }
205
206 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
207 "allowed as shape.lib attribute");
208 }
209 return success();
210 }
211
212 //===----------------------------------------------------------------------===//
213 // AnyOp
214 //===----------------------------------------------------------------------===//
215
216 // TODO: Canonicalization should be implemented for shapes that can be
217 // determined through mixtures of the known dimensions of the inputs.
fold(ArrayRef<Attribute> operands)218 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
219 // Only the last operand is checked because AnyOp is commutative.
220 if (operands.back())
221 return operands.back();
222
223 return nullptr;
224 }
225
226 //===----------------------------------------------------------------------===//
227 // AssumingOp
228 //===----------------------------------------------------------------------===//
229
parse(OpAsmParser & parser,OperationState & result)230 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
231 result.regions.reserve(1);
232 Region *doRegion = result.addRegion();
233
234 auto &builder = parser.getBuilder();
235 OpAsmParser::UnresolvedOperand cond;
236 if (parser.parseOperand(cond) ||
237 parser.resolveOperand(cond, builder.getType<WitnessType>(),
238 result.operands))
239 return failure();
240
241 // Parse optional results type list.
242 if (parser.parseOptionalArrowTypeList(result.types))
243 return failure();
244
245 // Parse the region and add a terminator if elided.
246 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
247 return failure();
248 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
249
250 // Parse the optional attribute list.
251 if (parser.parseOptionalAttrDict(result.attributes))
252 return failure();
253 return success();
254 }
255
print(OpAsmPrinter & p)256 void AssumingOp::print(OpAsmPrinter &p) {
257 bool yieldsResults = !getResults().empty();
258
259 p << " " << getWitness();
260 if (yieldsResults)
261 p << " -> (" << getResultTypes() << ")";
262 p << ' ';
263 p.printRegion(getDoRegion(),
264 /*printEntryBlockArgs=*/false,
265 /*printBlockTerminators=*/yieldsResults);
266 p.printOptionalAttrDict((*this)->getAttrs());
267 }
268
269 namespace {
270 // Removes AssumingOp with a passing witness and inlines the region.
271 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
272 using OpRewritePattern<AssumingOp>::OpRewritePattern;
273
matchAndRewrite__anonde1778d90411::AssumingWithTrue274 LogicalResult matchAndRewrite(AssumingOp op,
275 PatternRewriter &rewriter) const override {
276 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
277 if (!witness || !witness.getPassingAttr())
278 return failure();
279
280 AssumingOp::inlineRegionIntoParent(op, rewriter);
281 return success();
282 }
283 };
284
285 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
286 using OpRewritePattern<AssumingOp>::OpRewritePattern;
287
matchAndRewrite__anonde1778d90411::AssumingOpRemoveUnusedResults288 LogicalResult matchAndRewrite(AssumingOp op,
289 PatternRewriter &rewriter) const override {
290 Block *body = op.getBody();
291 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
292
293 // Find used values.
294 SmallVector<Value, 4> newYieldOperands;
295 Value opResult, yieldOperand;
296 for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) {
297 std::tie(opResult, yieldOperand) = it;
298 if (!opResult.getUses().empty()) {
299 newYieldOperands.push_back(yieldOperand);
300 }
301 }
302
303 // Rewrite only if redundant results exist.
304 if (newYieldOperands.size() == yieldOp->getNumOperands())
305 return failure();
306
307 // Replace yield op in the old assuming op's body and move the entire region
308 // to the new assuming op.
309 rewriter.setInsertionPointToEnd(body);
310 auto newYieldOp =
311 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
312 rewriter.setInsertionPoint(op);
313 auto newOp = rewriter.create<AssumingOp>(
314 op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
315 newOp.getDoRegion().takeBody(op.getDoRegion());
316
317 // Use the new results to replace the previously used ones.
318 SmallVector<Value, 4> replacementValues;
319 auto src = newOp.getResults().begin();
320 for (auto it : op.getResults()) {
321 if (it.getUses().empty())
322 replacementValues.push_back(nullptr);
323 else
324 replacementValues.push_back(*src++);
325 }
326 rewriter.replaceOp(op, replacementValues);
327 return success();
328 }
329 };
330 } // namespace
331
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)332 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
333 MLIRContext *context) {
334 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
335 }
336
337 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)338 void AssumingOp::getSuccessorRegions(
339 Optional<unsigned> index, ArrayRef<Attribute> operands,
340 SmallVectorImpl<RegionSuccessor> ®ions) {
341 // AssumingOp has unconditional control flow into the region and back to the
342 // parent, so return the correct RegionSuccessor purely based on the index
343 // being None or 0.
344 if (index) {
345 regions.push_back(RegionSuccessor(getResults()));
346 return;
347 }
348
349 regions.push_back(RegionSuccessor(&getDoRegion()));
350 }
351
/// Replaces `op` with the contents of its body block.
///
/// Steps:
/// - Split the current block at the rewriter's insertion point.
/// - Inline the body region between the two halves.
/// - Replace the op's results with the operands of the body's last operation
///   (the yield), then erase that yield.
/// - Merge the three blocks back into one.
///
/// NOTE(review): this assumes the rewriter's insertion point is at `op` (as
/// when called from a pattern rooted at `op`) — confirm for any new callers.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Everything from the insertion point onward moves into a new block.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
371
build(OpBuilder & builder,OperationState & result,Value witness,function_ref<SmallVector<Value,2> (OpBuilder &,Location)> bodyBuilder)372 void AssumingOp::build(
373 OpBuilder &builder, OperationState &result, Value witness,
374 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
375
376 result.addOperands(witness);
377 Region *bodyRegion = result.addRegion();
378 bodyRegion->push_back(new Block);
379 Block &bodyBlock = bodyRegion->front();
380
381 // Build body.
382 OpBuilder::InsertionGuard guard(builder);
383 builder.setInsertionPointToStart(&bodyBlock);
384 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
385 builder.create<AssumingYieldOp>(result.location, yieldValues);
386
387 SmallVector<Type, 2> assumingTypes;
388 for (Value v : yieldValues)
389 assumingTypes.push_back(v.getType());
390 result.addTypes(assumingTypes);
391 }
392
393 //===----------------------------------------------------------------------===//
394 // AddOp
395 //===----------------------------------------------------------------------===//
396
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)397 LogicalResult mlir::shape::AddOp::inferReturnTypes(
398 MLIRContext *context, Optional<Location> location, ValueRange operands,
399 DictionaryAttr attributes, RegionRange regions,
400 SmallVectorImpl<Type> &inferredReturnTypes) {
401 if (operands[0].getType().isa<SizeType>() ||
402 operands[1].getType().isa<SizeType>())
403 inferredReturnTypes.assign({SizeType::get(context)});
404 else
405 inferredReturnTypes.assign({IndexType::get(context)});
406 return success();
407 }
408
isCompatibleReturnTypes(TypeRange l,TypeRange r)409 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
410 // SizeType is compatible with IndexType.
411 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
412 }
413
fold(ArrayRef<Attribute> operands)414 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
415 // add(x, 0) -> x
416 if (matchPattern(getRhs(), m_Zero()))
417 return getLhs();
418
419 return constFoldBinaryOp<IntegerAttr>(
420 operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
421 }
422
verify()423 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
424
425 //===----------------------------------------------------------------------===//
426 // AssumingAllOp
427 //===----------------------------------------------------------------------===//
428
429 namespace {
430
431 // Merge multiple `shape.assuming_all` operations together.
432 //
433 // %0 = shape.assuming_all %w0, %w1
434 // %1 = shape.assuming_all %w2, %0
435 //
436 // to:
437 //
438 // %0 = shape.assuming_all %w0, %w2, %w2
439 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
440 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
441
matchAndRewrite__anonde1778d90611::MergeAssumingAllOps442 LogicalResult matchAndRewrite(AssumingAllOp op,
443 PatternRewriter &rewriter) const override {
444 SmallVector<Value> operands;
445
446 for (Value operand : op.getInputs()) {
447 if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
448 operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
449 else
450 operands.push_back(operand);
451 }
452
453 // We didn't find any other `assuming_all` ops to merge with.
454 if (operands.size() == op.getNumOperands())
455 return failure();
456
457 // Replace with a new `assuming_all` operation with merged constraints.
458 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
459 return success();
460 }
461 };
462
463 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
464 // are subsumed by others.
465 //
466 // %0 = shape.cstr_broadcastable %shape0, %shape1
467 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
468 //
469 // %2 = shape.cstr_broadcastable %shape3, %shape4
470 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
471 //
472 // %4 = shape.assuming_all %0, %1, %2, %3
473 //
474 // to:
475 //
476 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
477 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
478 // %2 = shape.assuming_all %0, %1
479 //
480 // In this example if shapes [0, 1, 2] are broadcastable, then it means that
481 // shapes [0, 1] are broadcastable too, and can be removed from the list of
482 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
483 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
484 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
485 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
486
matchAndRewrite__anonde1778d90611::AssumingAllOfCstrBroadcastable487 LogicalResult matchAndRewrite(AssumingAllOp op,
488 PatternRewriter &rewriter) const override {
489 // Collect all `CstrBroadcastableOp` operands first.
490 SetVector<CstrBroadcastableOp> operands;
491 for (Value operand : op.getInputs()) {
492 // TODO: Apply this optimization if some of the witnesses are not
493 // produced by the `cstr_broadcastable`.
494 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
495 if (!broadcastable)
496 return failure();
497
498 operands.insert(broadcastable);
499 }
500
501 // Skip trivial `assuming_all` operations.
502 if (operands.size() <= 1)
503 return failure();
504
505 // Collect shapes checked by `cstr_broadcastable` operands.
506 SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
507 for (auto cstr : operands) {
508 DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
509 shapes.emplace_back(cstr, std::move(shapesSet));
510 }
511
512 // Sort by the number of shape operands (larger to smaller).
513 llvm::sort(shapes, [](auto a, auto b) {
514 return a.first.getNumOperands() > b.first.getNumOperands();
515 });
516
517 // We start from the `cst_broadcastable` operations with largest number of
518 // shape operands, and remove redundant `cst_broadcastable` operations. We
519 // do this until we find a set of `cst_broadcastable` operations with
520 // non-overlapping constraints.
521 SmallVector<CstrBroadcastableOp> markedForErase;
522
523 for (unsigned i = 0; i < shapes.size(); ++i) {
524 auto isSubset = [&](auto pair) {
525 return llvm::set_is_subset(pair.second, shapes[i].second);
526 };
527
528 // Keep redundant `cstr_broadcastable` operations to be erased.
529 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
530 for (auto *it0 = it; it0 < shapes.end(); ++it0)
531 markedForErase.push_back(it0->first);
532 shapes.erase(it, shapes.end());
533 }
534
535 // We didn't find any operands that could be removed.
536 if (markedForErase.empty())
537 return failure();
538
539 // Collect non-overlapping `cst_broadcastable` constraints.
540 SmallVector<Value> uniqueConstraints;
541 for (auto &shape : shapes)
542 uniqueConstraints.push_back(shape.first.getResult());
543
544 // Replace with a new `assuming_all` operation ...
545 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
546
547 // ... and maybe erase `cstr_broadcastable` ops without uses.
548 for (auto &op : markedForErase)
549 if (op->use_empty())
550 rewriter.eraseOp(op);
551
552 return success();
553 }
554 };
555
556 struct AssumingAllToCstrEqCanonicalization
557 : public OpRewritePattern<AssumingAllOp> {
558 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
559
matchAndRewrite__anonde1778d90611::AssumingAllToCstrEqCanonicalization560 LogicalResult matchAndRewrite(AssumingAllOp op,
561 PatternRewriter &rewriter) const override {
562 SmallVector<Value, 8> shapes;
563 for (Value w : op.getInputs()) {
564 auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
565 if (!cstrEqOp)
566 return failure();
567 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
568 return llvm::is_contained(shapes, s);
569 });
570 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
571 return failure();
572 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
573 }
574 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
575 return success();
576 }
577 };
578
579 template <typename OpTy>
580 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
581 using OpRewritePattern<OpTy>::OpRewritePattern;
582
matchAndRewrite__anonde1778d90611::RemoveDuplicateOperandsPattern583 LogicalResult matchAndRewrite(OpTy op,
584 PatternRewriter &rewriter) const override {
585 // Find unique operands.
586 SetVector<Value> unique(op.operand_begin(), op.operand_end());
587
588 // Reduce op to equivalent with unique operands.
589 if (unique.size() < op.getNumOperands()) {
590 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
591 unique.takeVector(), op->getAttrs());
592 return success();
593 }
594
595 return failure();
596 }
597 };
598 } // namespace
599
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)600 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
601 MLIRContext *context) {
602 patterns
603 .add<MergeAssumingAllOps, AssumingAllOneOp,
604 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
605 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
606 }
607
fold(ArrayRef<Attribute> operands)608 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
609 // Iterate in reverse to first handle all constant operands. They are
610 // guaranteed to be the tail of the inputs because this is commutative.
611 for (int idx = operands.size() - 1; idx >= 0; idx--) {
612 Attribute a = operands[idx];
613 // Cannot fold if any inputs are not constant;
614 if (!a)
615 return nullptr;
616
617 // We do not need to keep statically known values after handling them in
618 // this method.
619 getOperation()->eraseOperand(idx);
620
621 // Always false if any input is statically known false
622 if (!a.cast<BoolAttr>().getValue())
623 return a;
624 }
625 // If this is reached, all inputs were statically known passing.
626 return BoolAttr::get(getContext(), true);
627 }
628
verify()629 LogicalResult AssumingAllOp::verify() {
630 // Ensure that AssumingAllOp contains at least one operand
631 if (getNumOperands() == 0)
632 return emitOpError("no operands specified");
633
634 return success();
635 }
636
637 //===----------------------------------------------------------------------===//
638 // BroadcastOp
639 //===----------------------------------------------------------------------===//
640
fold(ArrayRef<Attribute> operands)641 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
642 if (getShapes().size() == 1) {
643 // Otherwise, we need a cast which would be a canonicalization, not folding.
644 if (getShapes().front().getType() != getType())
645 return nullptr;
646 return getShapes().front();
647 }
648
649 // TODO: Support folding with more than 2 input shapes
650 if (getShapes().size() > 2)
651 return nullptr;
652
653 if (!operands[0] || !operands[1])
654 return nullptr;
655 auto lhsShape = llvm::to_vector<6>(
656 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
657 auto rhsShape = llvm::to_vector<6>(
658 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
659 SmallVector<int64_t, 6> resultShape;
660
661 // If the shapes are not compatible, we can't fold it.
662 // TODO: Fold to an "error".
663 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
664 return nullptr;
665
666 Builder builder(getContext());
667 return builder.getIndexTensorAttr(resultShape);
668 }
669
verify()670 LogicalResult BroadcastOp::verify() {
671 return verifyShapeOrExtentTensorOp(*this);
672 }
673
674 namespace {
675 template <typename OpTy>
676 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
677 using OpRewritePattern<OpTy>::OpRewritePattern;
678
matchAndRewrite__anonde1778d90a11::RemoveEmptyShapeOperandsPattern679 LogicalResult matchAndRewrite(OpTy op,
680 PatternRewriter &rewriter) const override {
681 auto isPotentiallyNonEmptyShape = [](Value shape) {
682 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
683 if (extentTensorTy.getDimSize(0) == 0)
684 return false;
685 }
686 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
687 if (constShape.getShape().empty())
688 return false;
689 }
690 return true;
691 };
692 auto newOperands = llvm::to_vector<8>(
693 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
694
695 // Reduce op to equivalent without empty shape operands.
696 if (newOperands.size() < op.getNumOperands()) {
697 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
698 op->getAttrs());
699 return success();
700 }
701
702 return failure();
703 }
704 };
705
706 struct BroadcastForwardSingleOperandPattern
707 : public OpRewritePattern<BroadcastOp> {
708 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
709
matchAndRewrite__anonde1778d90a11::BroadcastForwardSingleOperandPattern710 LogicalResult matchAndRewrite(BroadcastOp op,
711 PatternRewriter &rewriter) const override {
712 if (op.getNumOperands() != 1)
713 return failure();
714 Value replacement = op.getShapes().front();
715
716 // Insert cast if needed.
717 if (replacement.getType() != op.getType()) {
718 auto loc = op.getLoc();
719 if (op.getType().isa<ShapeType>()) {
720 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
721 } else {
722 assert(!op.getType().isa<ShapeType>() &&
723 !replacement.getType().isa<ShapeType>() &&
724 "expect extent tensor cast");
725 replacement =
726 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
727 }
728 }
729
730 rewriter.replaceOp(op, replacement);
731 return success();
732 }
733 };
734
735 struct BroadcastFoldConstantOperandsPattern
736 : public OpRewritePattern<BroadcastOp> {
737 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
738
matchAndRewrite__anonde1778d90a11::BroadcastFoldConstantOperandsPattern739 LogicalResult matchAndRewrite(BroadcastOp op,
740 PatternRewriter &rewriter) const override {
741 SmallVector<int64_t, 8> foldedConstantShape;
742 SmallVector<Value, 8> newShapeOperands;
743 for (Value shape : op.getShapes()) {
744 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
745 SmallVector<int64_t, 8> newFoldedConstantShape;
746 if (OpTrait::util::getBroadcastedShape(
747 foldedConstantShape,
748 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
749 newFoldedConstantShape)) {
750 foldedConstantShape = newFoldedConstantShape;
751 continue;
752 }
753 }
754 newShapeOperands.push_back(shape);
755 }
756
757 // Need at least two constant operands to fold anything.
758 if (op.getNumOperands() - newShapeOperands.size() < 2)
759 return failure();
760
761 auto foldedConstantOperandsTy = RankedTensorType::get(
762 {static_cast<int64_t>(foldedConstantShape.size())},
763 rewriter.getIndexType());
764 newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
765 op.getLoc(), foldedConstantOperandsTy,
766 rewriter.getIndexTensorAttr(foldedConstantShape)));
767 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
768 newShapeOperands);
769 return success();
770 }
771 };
772
773 template <typename OpTy>
774 struct CanonicalizeCastExtentTensorOperandsPattern
775 : public OpRewritePattern<OpTy> {
776 using OpRewritePattern<OpTy>::OpRewritePattern;
777
matchAndRewrite__anonde1778d90a11::CanonicalizeCastExtentTensorOperandsPattern778 LogicalResult matchAndRewrite(OpTy op,
779 PatternRewriter &rewriter) const override {
780 // Canonicalize operands.
781 bool anyChange = false;
782 auto canonicalizeOperand = [&](Value operand) {
783 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
784 // Only eliminate the cast if it holds no shape information.
785 bool isInformationLoosingCast =
786 castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
787 if (isInformationLoosingCast) {
788 anyChange = true;
789 return castOp.getSource();
790 }
791 }
792 return operand;
793 };
794 auto newOperands = llvm::to_vector<8>(
795 llvm::map_range(op.getOperands(), canonicalizeOperand));
796
797 // Rewrite op if any change required.
798 if (!anyChange)
799 return failure();
800 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
801 return success();
802 }
803 };
804
805 struct BroadcastConcretizeResultTypePattern
806 : public OpRewritePattern<BroadcastOp> {
807 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
808
matchAndRewrite__anonde1778d90a11::BroadcastConcretizeResultTypePattern809 LogicalResult matchAndRewrite(BroadcastOp op,
810 PatternRewriter &rewriter) const override {
811 // Only concretize dynamic extent tensor result types.
812 auto resultTy = op.getType().dyn_cast<RankedTensorType>();
813 if (!resultTy || !resultTy.isDynamicDim(0))
814 return failure();
815
816 // Infer resulting shape rank if possible.
817 int64_t maxRank = 0;
818 for (Value shape : op.getShapes()) {
819 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
820 // Cannot infer resulting shape rank if any operand is dynamically
821 // ranked.
822 if (extentTensorTy.isDynamicDim(0))
823 return failure();
824 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
825 }
826 }
827
828 auto newOp = rewriter.create<BroadcastOp>(
829 op.getLoc(), getExtentTensorType(getContext(), maxRank),
830 op.getShapes());
831 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
832 return success();
833 }
834 };
835 } // namespace
836
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)837 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
838 MLIRContext *context) {
839 patterns.add<BroadcastConcretizeResultTypePattern,
840 BroadcastFoldConstantOperandsPattern,
841 BroadcastForwardSingleOperandPattern,
842 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
843 RemoveDuplicateOperandsPattern<BroadcastOp>,
844 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
845 }
846
847 //===----------------------------------------------------------------------===//
848 // ConcatOp
849 //===----------------------------------------------------------------------===//
850
fold(ArrayRef<Attribute> operands)851 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
852 if (!operands[0] || !operands[1])
853 return nullptr;
854 auto lhsShape = llvm::to_vector<6>(
855 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
856 auto rhsShape = llvm::to_vector<6>(
857 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
858 SmallVector<int64_t, 6> resultShape;
859 resultShape.append(lhsShape.begin(), lhsShape.end());
860 resultShape.append(rhsShape.begin(), rhsShape.end());
861 Builder builder(getContext());
862 return builder.getIndexTensorAttr(resultShape);
863 }
864
865 //===----------------------------------------------------------------------===//
866 // ConstShapeOp
867 //===----------------------------------------------------------------------===//
868
print(OpAsmPrinter & p)869 void ConstShapeOp::print(OpAsmPrinter &p) {
870 p << " ";
871 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
872 p << "[";
873 interleaveComma(getShape().getValues<int64_t>(), p);
874 p << "] : ";
875 p.printType(getType());
876 }
877
parse(OpAsmParser & parser,OperationState & result)878 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
879 if (parser.parseOptionalAttrDict(result.attributes))
880 return failure();
881 // We piggy-back on ArrayAttr parsing, though we don't internally store the
882 // shape as an ArrayAttr.
883 // TODO: Implement custom parser and maybe make syntax a bit more concise.
884 Attribute extentsRaw;
885 NamedAttrList dummy;
886 if (parser.parseAttribute(extentsRaw, "dummy", dummy))
887 return failure();
888 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
889 if (!extentsArray)
890 return failure();
891 SmallVector<int64_t, 6> ints;
892 for (Attribute extent : extentsArray) {
893 IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
894 if (!attr)
895 return failure();
896 ints.push_back(attr.getInt());
897 }
898 Builder &builder = parser.getBuilder();
899 result.addAttribute("shape", builder.getIndexTensorAttr(ints));
900 Type resultTy;
901 if (parser.parseColonType(resultTy))
902 return failure();
903 result.types.push_back(resultTy);
904 return success();
905 }
906
fold(ArrayRef<Attribute>)907 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }
908
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)909 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
910 MLIRContext *context) {
911 patterns.add<TensorCastConstShape>(context);
912 }
913
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)914 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
915 MLIRContext *context, Optional<Location> location, ValueRange operands,
916 DictionaryAttr attributes, RegionRange regions,
917 SmallVectorImpl<Type> &inferredReturnTypes) {
918 Builder b(context);
919 auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
920 if (!shape)
921 return emitOptionalError(location, "missing shape attribute");
922 inferredReturnTypes.assign({RankedTensorType::get(
923 {static_cast<int64_t>(shape.size())}, b.getIndexType())});
924 return success();
925 }
926
isCompatibleReturnTypes(TypeRange l,TypeRange r)927 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
928 TypeRange r) {
929 if (l.size() != 1 || r.size() != 1)
930 return false;
931
932 Type lhs = l.front();
933 Type rhs = r.front();
934
935 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
936 // Shape type is compatible with all other valid return types.
937 return true;
938 return lhs == rhs;
939 }
940
941 //===----------------------------------------------------------------------===//
942 // CstrBroadcastableOp
943 //===----------------------------------------------------------------------===//
944
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)945 void CstrBroadcastableOp::getCanonicalizationPatterns(
946 RewritePatternSet &patterns, MLIRContext *context) {
947 // Canonicalization patterns have overlap with the considerations during
948 // folding in case additional shape information is inferred at some point that
949 // does not result in folding.
950 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
951 CstrBroadcastableEqOps,
952 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
953 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
954 }
955
956 // Return true if there is exactly one attribute not representing a scalar
957 // broadcast.
hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes)958 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
959 bool nonScalarSeen = false;
960 for (Attribute a : attributes) {
961 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
962 if (nonScalarSeen)
963 return false;
964 nonScalarSeen = true;
965 }
966 }
967 return true;
968 }
969
fold(ArrayRef<Attribute> operands)970 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
971 // No broadcasting is needed if all operands but one are scalar.
972 if (hasAtMostSingleNonScalar(operands))
973 return BoolAttr::get(getContext(), true);
974
975 if ([&] {
976 SmallVector<SmallVector<int64_t, 6>, 6> extents;
977 for (const auto &operand : operands) {
978 if (!operand)
979 return false;
980 extents.push_back(llvm::to_vector<6>(
981 operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
982 }
983 return OpTrait::util::staticallyKnownBroadcastable(extents);
984 }())
985 return BoolAttr::get(getContext(), true);
986
987 // Lastly, see if folding can be completed based on what constraints are known
988 // on the input shapes.
989 if ([&] {
990 SmallVector<SmallVector<int64_t, 6>, 6> extents;
991 for (auto shapeValue : getShapes()) {
992 extents.emplace_back();
993 if (failed(getShapeVec(shapeValue, extents.back())))
994 return false;
995 }
996 return OpTrait::util::staticallyKnownBroadcastable(extents);
997 }())
998 return BoolAttr::get(getContext(), true);
999
1000 // Because a failing witness result here represents an eventual assertion
1001 // failure, we do not replace it with a constant witness.
1002 return nullptr;
1003 }
1004
verify()1005 LogicalResult CstrBroadcastableOp::verify() {
1006 // Ensure that CstrBroadcastableOp contains at least two operands
1007 if (getNumOperands() < 2)
1008 return emitOpError("required at least 2 input shapes");
1009 return success();
1010 }
1011
1012 //===----------------------------------------------------------------------===//
1013 // CstrEqOp
1014 //===----------------------------------------------------------------------===//
1015
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1016 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1017 MLIRContext *context) {
1018 // If inputs are equal, return passing witness
1019 patterns.add<CstrEqEqOps>(context);
1020 }
1021
fold(ArrayRef<Attribute> operands)1022 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
1023 if (llvm::all_of(operands,
1024 [&](Attribute a) { return a && a == operands[0]; }))
1025 return BoolAttr::get(getContext(), true);
1026
1027 // Because a failing witness result here represents an eventual assertion
1028 // failure, we do not try to replace it with a constant witness. Similarly, we
1029 // cannot if there are any non-const inputs.
1030 return nullptr;
1031 }
1032
1033 //===----------------------------------------------------------------------===//
1034 // ConstSizeOp
1035 //===----------------------------------------------------------------------===//
1036
build(OpBuilder & builder,OperationState & result,int64_t value)1037 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
1038 int64_t value) {
1039 build(builder, result, builder.getIndexAttr(value));
1040 }
1041
fold(ArrayRef<Attribute>)1042 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }
1043
getAsmResultNames(llvm::function_ref<void (Value,StringRef)> setNameFn)1044 void ConstSizeOp::getAsmResultNames(
1045 llvm::function_ref<void(Value, StringRef)> setNameFn) {
1046 SmallString<4> buffer;
1047 llvm::raw_svector_ostream os(buffer);
1048 os << "c" << getValue();
1049 setNameFn(getResult(), os.str());
1050 }
1051
1052 //===----------------------------------------------------------------------===//
1053 // ConstWitnessOp
1054 //===----------------------------------------------------------------------===//
1055
fold(ArrayRef<Attribute>)1056 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
1057 return getPassingAttr();
1058 }
1059
1060 //===----------------------------------------------------------------------===//
1061 // CstrRequireOp
1062 //===----------------------------------------------------------------------===//
1063
fold(ArrayRef<Attribute> operands)1064 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
1065 return operands[0];
1066 }
1067
1068 //===----------------------------------------------------------------------===//
1069 // DivOp
1070 //===----------------------------------------------------------------------===//
1071
fold(ArrayRef<Attribute> operands)1072 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
1073 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1074 if (!lhs)
1075 return nullptr;
1076 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1077 if (!rhs)
1078 return nullptr;
1079
1080 // Division in APInt does not follow floor(lhs, rhs) when the result is
1081 // negative. Rather, APInt rounds toward zero.
1082 APInt quotient, remainder;
1083 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1084 if (quotient.isNegative() && !remainder.isNullValue()) {
1085 quotient -= 1;
1086 }
1087
1088 Type indexTy = IndexType::get(getContext());
1089 return IntegerAttr::get(indexTy, quotient);
1090 }
1091
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1092 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1093 MLIRContext *context, Optional<Location> location, ValueRange operands,
1094 DictionaryAttr attributes, RegionRange regions,
1095 SmallVectorImpl<Type> &inferredReturnTypes) {
1096 if (operands[0].getType().isa<SizeType>() ||
1097 operands[1].getType().isa<SizeType>())
1098 inferredReturnTypes.assign({SizeType::get(context)});
1099 else
1100 inferredReturnTypes.assign({IndexType::get(context)});
1101 return success();
1102 }
1103
isCompatibleReturnTypes(TypeRange l,TypeRange r)1104 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1105 // SizeType is compatible with IndexType.
1106 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1107 }
1108
verify()1109 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1110
1111 //===----------------------------------------------------------------------===//
1112 // ShapeEqOp
1113 //===----------------------------------------------------------------------===//
1114
fold(ArrayRef<Attribute> operands)1115 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
1116 bool allSame = true;
1117 if (!operands.empty() && !operands[0])
1118 return {};
1119 for (Attribute operand : operands.drop_front(1)) {
1120 if (!operand)
1121 return {};
1122 allSame = allSame && operand == operands[0];
1123 }
1124 return BoolAttr::get(getContext(), allSame);
1125 }
1126
1127 //===----------------------------------------------------------------------===//
1128 // IndexToSizeOp
1129 //===----------------------------------------------------------------------===//
1130
fold(ArrayRef<Attribute> operands)1131 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
1132 // Constant values of both types, `shape.size` and `index`, are represented as
1133 // `IntegerAttr`s which makes constant folding simple.
1134 if (Attribute arg = operands[0])
1135 return arg;
1136 return {};
1137 }
1138
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1139 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1140 MLIRContext *context) {
1141 patterns.add<SizeToIndexToSizeCanonicalization>(context);
1142 }
1143
1144 //===----------------------------------------------------------------------===//
1145 // FromExtentsOp
1146 //===----------------------------------------------------------------------===//
1147
fold(ArrayRef<Attribute> operands)1148 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
1149 if (llvm::any_of(operands, [](Attribute a) { return !a; }))
1150 return nullptr;
1151 SmallVector<int64_t, 6> extents;
1152 for (auto attr : operands)
1153 extents.push_back(attr.cast<IntegerAttr>().getInt());
1154 Builder builder(getContext());
1155 return builder.getIndexTensorAttr(extents);
1156 }
1157
1158 //===----------------------------------------------------------------------===//
1159 // FunctionLibraryOp
1160 //===----------------------------------------------------------------------===//
1161
build(OpBuilder & builder,OperationState & result,StringRef name)1162 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1163 StringRef name) {
1164 result.attributes.push_back(builder.getNamedAttr(
1165 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1166 }
1167
getShapeFunction(Operation * op)1168 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1169 auto attr = getMapping()
1170 .get(op->getName().getIdentifier())
1171 .dyn_cast_or_null<FlatSymbolRefAttr>();
1172 if (!attr)
1173 return nullptr;
1174 return lookupSymbol<FuncOp>(attr);
1175 }
1176
parse(OpAsmParser & parser,OperationState & result)1177 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
1178 OperationState &result) {
1179 // Parse the op name.
1180 StringAttr nameAttr;
1181 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1182 result.attributes))
1183 return failure();
1184
1185 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1186 return failure();
1187
1188 auto *bodyRegion = result.addRegion();
1189 if (parser.parseRegion(*bodyRegion))
1190 return failure();
1191
1192 if (parser.parseKeyword("mapping"))
1193 return failure();
1194
1195 DictionaryAttr mappingAttr;
1196 if (parser.parseAttribute(mappingAttr,
1197 parser.getBuilder().getType<NoneType>(), "mapping",
1198 result.attributes))
1199 return failure();
1200 return success();
1201 }
1202
print(OpAsmPrinter & p)1203 void FunctionLibraryOp::print(OpAsmPrinter &p) {
1204 p << ' ';
1205 p.printSymbolName(getName());
1206 p.printOptionalAttrDictWithKeyword(
1207 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
1208 p << ' ';
1209 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1210 /*printBlockTerminators=*/false);
1211 p << " mapping ";
1212 p.printAttributeWithoutType(getMappingAttr());
1213 }
1214
1215 //===----------------------------------------------------------------------===//
1216 // FuncOp
1217 //===----------------------------------------------------------------------===//
1218
parse(OpAsmParser & parser,OperationState & result)1219 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
1220 auto buildFuncType =
1221 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
1222 function_interface_impl::VariadicFlag,
1223 std::string &) { return builder.getFunctionType(argTypes, results); };
1224
1225 return function_interface_impl::parseFunctionOp(
1226 parser, result, /*allowVariadic=*/false, buildFuncType);
1227 }
1228
print(OpAsmPrinter & p)1229 void FuncOp::print(OpAsmPrinter &p) {
1230 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
1231 }
1232
1233 //===----------------------------------------------------------------------===//
1234 // GetExtentOp
1235 //===----------------------------------------------------------------------===//
1236
getConstantDim()1237 Optional<int64_t> GetExtentOp::getConstantDim() {
1238 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1239 return constSizeOp.getValue().getLimitedValue();
1240 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1241 return constantOp.getValue().cast<IntegerAttr>().getInt();
1242 return llvm::None;
1243 }
1244
fold(ArrayRef<Attribute> operands)1245 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
1246 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1247 if (!elements)
1248 return nullptr;
1249 Optional<int64_t> dim = getConstantDim();
1250 if (!dim.has_value())
1251 return nullptr;
1252 if (dim.value() >= elements.getNumElements())
1253 return nullptr;
1254 return elements.getValues<Attribute>()[(uint64_t)dim.value()];
1255 }
1256
build(OpBuilder & builder,OperationState & result,Value shape,int64_t dim)1257 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1258 int64_t dim) {
1259 auto loc = result.location;
1260 auto dimAttr = builder.getIndexAttr(dim);
1261 if (shape.getType().isa<ShapeType>()) {
1262 Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1263 build(builder, result, builder.getType<SizeType>(), shape, dim);
1264 } else {
1265 Value dim =
1266 builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1267 build(builder, result, builder.getIndexType(), shape, dim);
1268 }
1269 }
1270
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1271 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1272 MLIRContext *context, Optional<Location> location, ValueRange operands,
1273 DictionaryAttr attributes, RegionRange regions,
1274 SmallVectorImpl<Type> &inferredReturnTypes) {
1275 inferredReturnTypes.assign({IndexType::get(context)});
1276 return success();
1277 }
1278
isCompatibleReturnTypes(TypeRange l,TypeRange r)1279 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1280 TypeRange r) {
1281 // SizeType is compatible with IndexType.
1282 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1283 }
1284
verify()1285 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1286
1287 //===----------------------------------------------------------------------===//
1288 // IsBroadcastableOp
1289 //===----------------------------------------------------------------------===//
1290
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1291 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1292 MLIRContext *context) {
1293 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1294 }
1295
fold(ArrayRef<Attribute> operands)1296 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1297 // Can always broadcast fewer than two shapes.
1298 if (operands.size() < 2) {
1299 return BoolAttr::get(getContext(), true);
1300 }
1301
1302 return nullptr;
1303 }
1304
1305 //===----------------------------------------------------------------------===//
1306 // MeetOp
1307 //===----------------------------------------------------------------------===//
1308
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1309 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1310 MLIRContext *context, Optional<Location> location, ValueRange operands,
1311 DictionaryAttr attributes, RegionRange regions,
1312 SmallVectorImpl<Type> &inferredReturnTypes) {
1313 inferredReturnTypes.assign({operands[0].getType()});
1314 return success();
1315 }
1316
isCompatibleReturnTypes(TypeRange l,TypeRange r)1317 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1318 if (l.size() != 1 || r.size() != 1)
1319 return false;
1320 if (l == r)
1321 return true;
1322
1323 Type lhs = l.front();
1324 Type rhs = r.front();
1325
1326 if (lhs != rhs)
1327 return false;
1328
1329 if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
1330 return true;
1331
1332 if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1333 return true;
1334 return false;
1335 }
1336
1337 //===----------------------------------------------------------------------===//
1338 // RankOp
1339 //===----------------------------------------------------------------------===//
1340
fold(ArrayRef<Attribute> operands)1341 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1342 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1343 if (!shape)
1344 return {};
1345 int64_t rank = shape.getNumElements();
1346 Builder builder(getContext());
1347 return builder.getIndexAttr(rank);
1348 }
1349
1350 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1351 /// Constant folding fails in cases where only the rank is constant, not the
1352 /// shape itself.
1353 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1354 ///
1355 /// Example:
1356 ///
1357 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1358 /// %rank = shape.rank %shape
1359 ///
1360 /// becomes
1361 ///
1362 /// %rank = shape.const_size 3
1363
1364 namespace {
1365 struct RankShapeOfCanonicalizationPattern
1366 : public OpRewritePattern<shape::RankOp> {
1367 using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1368
matchAndRewrite__anonde1778d91211::RankShapeOfCanonicalizationPattern1369 LogicalResult matchAndRewrite(shape::RankOp op,
1370 PatternRewriter &rewriter) const override {
1371 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1372 if (!shapeOfOp)
1373 return failure();
1374 auto rankedTensorType =
1375 shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
1376 if (!rankedTensorType)
1377 return failure();
1378 int64_t rank = rankedTensorType.getRank();
1379 if (op.getType().isa<IndexType>()) {
1380 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
1381 rank);
1382 } else if (op.getType().isa<shape::SizeType>()) {
1383 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1384 } else {
1385 return failure();
1386 }
1387 return success();
1388 }
1389 };
1390 } // namespace
1391
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1392 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1393 MLIRContext *context) {
1394 patterns.add<RankShapeOfCanonicalizationPattern>(context);
1395 }
1396
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1397 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1398 MLIRContext *context, Optional<Location> location, ValueRange operands,
1399 DictionaryAttr attributes, RegionRange regions,
1400 SmallVectorImpl<Type> &inferredReturnTypes) {
1401 if (operands[0].getType().isa<ShapeType>())
1402 inferredReturnTypes.assign({SizeType::get(context)});
1403 else
1404 inferredReturnTypes.assign({IndexType::get(context)});
1405 return success();
1406 }
1407
isCompatibleReturnTypes(TypeRange l,TypeRange r)1408 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1409 // SizeType is compatible with IndexType.
1410 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1411 }
1412
verify()1413 LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1414
1415 //===----------------------------------------------------------------------===//
1416 // NumElementsOp
1417 //===----------------------------------------------------------------------===//
1418
fold(ArrayRef<Attribute> operands)1419 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
1420
1421 // Fold only when argument constant.
1422 Attribute shape = operands[0];
1423 if (!shape)
1424 return {};
1425
1426 APInt product(64, 1);
1427 for (auto value : shape.cast<DenseIntElementsAttr>())
1428 product *= value;
1429 Builder builder(getContext());
1430 return builder.getIndexAttr(product.getLimitedValue());
1431 }
1432
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1433 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1434 MLIRContext *context, Optional<Location> location, ValueRange operands,
1435 DictionaryAttr attributes, RegionRange regions,
1436 SmallVectorImpl<Type> &inferredReturnTypes) {
1437 if (operands[0].getType().isa<ShapeType>())
1438 inferredReturnTypes.assign({SizeType::get(context)});
1439 else
1440 inferredReturnTypes.assign({IndexType::get(context)});
1441 return success();
1442 }
1443
isCompatibleReturnTypes(TypeRange l,TypeRange r)1444 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1445 TypeRange r) {
1446 // SizeType is compatible with IndexType.
1447 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1448 }
1449
verify()1450 LogicalResult shape::NumElementsOp::verify() {
1451 return verifySizeOrIndexOp(*this);
1452 }
1453
1454 //===----------------------------------------------------------------------===//
1455 // MaxOp
1456 //===----------------------------------------------------------------------===//
1457
fold(llvm::ArrayRef<mlir::Attribute> operands)1458 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1459 // If operands are equal, just propagate one.
1460 if (getLhs() == getRhs())
1461 return getLhs();
1462 return nullptr;
1463 }
1464
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1465 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1466 MLIRContext *context, Optional<Location> location, ValueRange operands,
1467 DictionaryAttr attributes, RegionRange regions,
1468 SmallVectorImpl<Type> &inferredReturnTypes) {
1469 if (operands[0].getType() == operands[1].getType())
1470 inferredReturnTypes.assign({operands[0].getType()});
1471 else
1472 inferredReturnTypes.assign({SizeType::get(context)});
1473 return success();
1474 }
1475
isCompatibleReturnTypes(TypeRange l,TypeRange r)1476 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1477 if (l.size() != 1 || r.size() != 1)
1478 return false;
1479 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1480 return true;
1481 if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1482 return true;
1483 return false;
1484 }
1485
1486 //===----------------------------------------------------------------------===//
1487 // MinOp
1488 //===----------------------------------------------------------------------===//
1489
fold(llvm::ArrayRef<mlir::Attribute> operands)1490 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1491 // If operands are equal, just propagate one.
1492 if (getLhs() == getRhs())
1493 return getLhs();
1494 return nullptr;
1495 }
1496
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1497 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1498 MLIRContext *context, Optional<Location> location, ValueRange operands,
1499 DictionaryAttr attributes, RegionRange regions,
1500 SmallVectorImpl<Type> &inferredReturnTypes) {
1501 if (operands[0].getType() == operands[1].getType())
1502 inferredReturnTypes.assign({operands[0].getType()});
1503 else
1504 inferredReturnTypes.assign({SizeType::get(context)});
1505 return success();
1506 }
1507
isCompatibleReturnTypes(TypeRange l,TypeRange r)1508 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1509 if (l.size() != 1 || r.size() != 1)
1510 return false;
1511 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1512 return true;
1513 if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1514 return true;
1515 return false;
1516 }
1517
1518 //===----------------------------------------------------------------------===//
1519 // MulOp
1520 //===----------------------------------------------------------------------===//
1521
fold(ArrayRef<Attribute> operands)1522 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1523 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1524 if (!lhs)
1525 return nullptr;
1526 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1527 if (!rhs)
1528 return nullptr;
1529 APInt folded = lhs.getValue() * rhs.getValue();
1530 Type indexTy = IndexType::get(getContext());
1531 return IntegerAttr::get(indexTy, folded);
1532 }
1533
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1534 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1535 MLIRContext *context, Optional<Location> location, ValueRange operands,
1536 DictionaryAttr attributes, RegionRange regions,
1537 SmallVectorImpl<Type> &inferredReturnTypes) {
1538 if (operands[0].getType().isa<SizeType>() ||
1539 operands[1].getType().isa<SizeType>())
1540 inferredReturnTypes.assign({SizeType::get(context)});
1541 else
1542 inferredReturnTypes.assign({IndexType::get(context)});
1543 return success();
1544 }
1545
isCompatibleReturnTypes(TypeRange l,TypeRange r)1546 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1547 // SizeType is compatible with IndexType.
1548 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1549 }
1550
verify()1551 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1552
1553 //===----------------------------------------------------------------------===//
1554 // ShapeOfOp
1555 //===----------------------------------------------------------------------===//
1556
fold(ArrayRef<Attribute>)1557 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1558 auto type = getOperand().getType().dyn_cast<ShapedType>();
1559 if (!type || !type.hasStaticShape())
1560 return nullptr;
1561 Builder builder(getContext());
1562 return builder.getIndexTensorAttr(type.getShape());
1563 }
1564
1565 namespace {
1566 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
1567 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1568
matchAndRewrite__anonde1778d91311::ShapeOfWithTensor1569 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1570 PatternRewriter &rewriter) const override {
1571 if (!op.getArg().getType().isa<ShapedType>())
1572 return failure();
1573 if (op.getType().isa<ShapedType>())
1574 return failure();
1575
1576 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
1577 op.getArg());
1578 return success();
1579 }
1580 };
1581
1582 // Canonicalize
1583 // ```
1584 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1585 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1586 // ```
1587 // to
1588 // ```
1589 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1590 // ```
1591 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1592 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1593
matchAndRewrite__anonde1778d91311::ShapeOfCastExtentTensor1594 LogicalResult matchAndRewrite(tensor::CastOp op,
1595 PatternRewriter &rewriter) const override {
1596 auto ty = op.getType().dyn_cast<RankedTensorType>();
1597 if (!ty || ty.getRank() != 1)
1598 return failure();
1599
1600 auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1601 if (!shapeOfOp)
1602 return failure();
1603
1604 // Argument type must be ranked and must not conflict.
1605 auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
1606 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1607 return failure();
1608
1609 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1610 return success();
1611 }
1612 };
1613 } // namespace
1614
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1615 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1616 MLIRContext *context) {
1617 patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
1618 ExtractFromShapeOfExtentTensor>(context);
1619 }
1620
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1621 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1622 MLIRContext *context, Optional<Location> location, ValueRange operands,
1623 DictionaryAttr attributes, RegionRange regions,
1624 SmallVectorImpl<Type> &inferredReturnTypes) {
1625 if (operands[0].getType().isa<ValueShapeType>())
1626 inferredReturnTypes.assign({ShapeType::get(context)});
1627 else {
1628 auto shapedTy = operands[0].getType().cast<ShapedType>();
1629 int64_t rank =
1630 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1631 Type indexTy = IndexType::get(context);
1632 Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1633 inferredReturnTypes.assign({extentTensorTy});
1634 }
1635 return success();
1636 }
1637
isCompatibleReturnTypes(TypeRange l,TypeRange r)1638 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1639 if (l.size() != 1 || r.size() != 1)
1640 return false;
1641 if (l == r)
1642 return true;
1643
1644 Type lhs = l.front();
1645 Type rhs = r.front();
1646
1647 if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
1648 return false;
1649
1650 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
1651 // Shape type is compatible with all other valid return types.
1652 return true;
1653
1654 if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1655 return true;
1656 return false;
1657 }
1658
verify()1659 LogicalResult shape::ShapeOfOp::verify() {
1660 return verifyShapeOrExtentTensorOp(*this);
1661 }
1662
1663 //===----------------------------------------------------------------------===//
1664 // SizeToIndexOp
1665 //===----------------------------------------------------------------------===//
1666
fold(ArrayRef<Attribute> operands)1667 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1668 // Constant values of both types, `shape.size` and `index`, are represented as
1669 // `IntegerAttr`s which makes constant folding simple.
1670 if (Attribute arg = operands[0])
1671 return arg;
1672 return OpFoldResult();
1673 }
1674
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1675 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1676 MLIRContext *context) {
1677 patterns.add<IndexToSizeToIndexCanonicalization>(context);
1678 }
1679
areCastCompatible(TypeRange inputs,TypeRange outputs)1680 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1681 if (inputs.size() != 1 || outputs.size() != 1)
1682 return false;
1683 return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
1684 }
1685
1686 //===----------------------------------------------------------------------===//
1687 // YieldOp
1688 //===----------------------------------------------------------------------===//
1689
verify()1690 LogicalResult shape::YieldOp::verify() {
1691 auto *parentOp = (*this)->getParentOp();
1692 auto results = parentOp->getResults();
1693 auto operands = getOperands();
1694
1695 if (parentOp->getNumResults() != getNumOperands())
1696 return emitOpError() << "number of operands does not match number of "
1697 "results of its parent";
1698 for (auto e : llvm::zip(results, operands))
1699 if (std::get<0>(e).getType() != std::get<1>(e).getType())
1700 return emitOpError() << "types mismatch between yield op and its parent";
1701
1702 return success();
1703 }
1704
1705 //===----------------------------------------------------------------------===//
1706 // SplitAtOp
1707 //===----------------------------------------------------------------------===//
1708
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1709 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1710 SmallVectorImpl<OpFoldResult> &results) {
1711 if (!operands[0] || !operands[1])
1712 return failure();
1713 auto shapeVec = llvm::to_vector<6>(
1714 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1715 auto shape = llvm::makeArrayRef(shapeVec);
1716 auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1717 // Verify that the split point is in the correct range.
1718 // TODO: Constant fold to an "error".
1719 int64_t rank = shape.size();
1720 if (-rank > splitPoint || splitPoint > rank)
1721 return failure();
1722 if (splitPoint < 0)
1723 splitPoint += shape.size();
1724 Builder builder(operands[0].getContext());
1725 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1726 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1727 return success();
1728 }
1729
1730 //===----------------------------------------------------------------------===//
1731 // ToExtentTensorOp
1732 //===----------------------------------------------------------------------===//
1733
fold(ArrayRef<Attribute> operands)1734 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1735 if (!operands[0])
1736 return OpFoldResult();
1737 Builder builder(getContext());
1738 auto shape = llvm::to_vector<6>(
1739 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1740 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1741 builder.getIndexType());
1742 return DenseIntElementsAttr::get(type, shape);
1743 }
1744
areCastCompatible(TypeRange inputs,TypeRange outputs)1745 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1746 if (inputs.size() != 1 || outputs.size() != 1)
1747 return false;
1748 if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
1749 if (!inputTensor.getElementType().isa<IndexType>() ||
1750 inputTensor.getRank() != 1)
1751 return false;
1752 } else if (!inputs[0].isa<ShapeType>()) {
1753 return false;
1754 }
1755
1756 TensorType outputTensor = outputs[0].dyn_cast<TensorType>();
1757 return outputTensor && outputTensor.getElementType().isa<IndexType>();
1758 }
1759
1760 //===----------------------------------------------------------------------===//
1761 // ReduceOp
1762 //===----------------------------------------------------------------------===//
1763
build(OpBuilder & builder,OperationState & result,Value shape,ValueRange initVals)1764 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1765 ValueRange initVals) {
1766 result.addOperands(shape);
1767 result.addOperands(initVals);
1768
1769 Region *bodyRegion = result.addRegion();
1770 bodyRegion->push_back(new Block);
1771 Block &bodyBlock = bodyRegion->front();
1772 bodyBlock.addArgument(builder.getIndexType(), result.location);
1773
1774 Type elementType;
1775 if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1776 elementType = tensorType.getElementType();
1777 else
1778 elementType = SizeType::get(builder.getContext());
1779 bodyBlock.addArgument(elementType, shape.getLoc());
1780
1781 for (Value initVal : initVals) {
1782 bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
1783 result.addTypes(initVal.getType());
1784 }
1785 }
1786
verify()1787 LogicalResult ReduceOp::verify() {
1788 // Verify block arg types.
1789 Block &block = getRegion().front();
1790
1791 // The block takes index, extent, and aggregated values as arguments.
1792 auto blockArgsCount = getInitVals().size() + 2;
1793 if (block.getNumArguments() != blockArgsCount)
1794 return emitOpError() << "ReduceOp body is expected to have "
1795 << blockArgsCount << " arguments";
1796
1797 // The first block argument is the index and must always be of type `index`.
1798 if (!block.getArgument(0).getType().isa<IndexType>())
1799 return emitOpError(
1800 "argument 0 of ReduceOp body is expected to be of IndexType");
1801
1802 // The second block argument is the extent and must be of type `size` or
1803 // `index`, depending on whether the reduce operation is applied to a shape or
1804 // to an extent tensor.
1805 Type extentTy = block.getArgument(1).getType();
1806 if (getShape().getType().isa<ShapeType>()) {
1807 if (!extentTy.isa<SizeType>())
1808 return emitOpError("argument 1 of ReduceOp body is expected to be of "
1809 "SizeType if the ReduceOp operates on a ShapeType");
1810 } else {
1811 if (!extentTy.isa<IndexType>())
1812 return emitOpError(
1813 "argument 1 of ReduceOp body is expected to be of IndexType if the "
1814 "ReduceOp operates on an extent tensor");
1815 }
1816
1817 for (const auto &type : llvm::enumerate(getInitVals()))
1818 if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1819 return emitOpError() << "type mismatch between argument "
1820 << type.index() + 2
1821 << " of ReduceOp body and initial value "
1822 << type.index();
1823 return success();
1824 }
1825
parse(OpAsmParser & parser,OperationState & result)1826 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1827 // Parse operands.
1828 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1829 Type shapeOrExtentTensorType;
1830 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1831 OpAsmParser::Delimiter::Paren) ||
1832 parser.parseColonType(shapeOrExtentTensorType) ||
1833 parser.parseOptionalArrowTypeList(result.types))
1834 return failure();
1835
1836 // Resolve operands.
1837 auto initVals = llvm::makeArrayRef(operands).drop_front();
1838 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1839 result.operands) ||
1840 parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1841 result.operands))
1842 return failure();
1843
1844 // Parse the body.
1845 Region *body = result.addRegion();
1846 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
1847 return failure();
1848
1849 // Parse attributes.
1850 if (parser.parseOptionalAttrDict(result.attributes))
1851 return failure();
1852
1853 return success();
1854 }
1855
print(OpAsmPrinter & p)1856 void ReduceOp::print(OpAsmPrinter &p) {
1857 p << '(' << getShape() << ", " << getInitVals()
1858 << ") : " << getShape().getType();
1859 p.printOptionalArrowTypeList(getResultTypes());
1860 p << ' ';
1861 p.printRegion(getRegion());
1862 p.printOptionalAttrDict((*this)->getAttrs());
1863 }
1864
1865 #define GET_OP_CLASSES
1866 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1867
1868 #define GET_TYPEDEF_CLASSES
1869 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
1870