1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/Traits.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/StandardTypes.h"
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/Support/raw_ostream.h"
18 
19 using namespace mlir;
20 using namespace mlir::shape;
21 
22 namespace {
23 #include "ShapeCanonicalization.inc"
24 }
25 
ShapeDialect::ShapeDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // Register all shape dialect operations generated from ODS.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  // Register the dialect's custom types.
  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
           WitnessType>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
39 
40 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
41                                              Attribute value, Type type,
42                                              Location loc) {
43   if (auto shapeType = type.dyn_cast<ShapeType>())
44     return builder.create<ConstShapeOp>(loc, type,
45                                         value.cast<DenseIntElementsAttr>());
46   if (auto sizeType = type.dyn_cast<SizeType>())
47     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
48   if (auto witnessType = type.dyn_cast<WitnessType>())
49     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
50   return nullptr;
51 }
52 
53 /// Parse a type registered to this dialect.
54 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
55   StringRef keyword;
56   if (parser.parseKeyword(&keyword))
57     return Type();
58 
59   if (keyword == "component")
60     return ComponentType::get(getContext());
61   if (keyword == "element")
62     return ElementType::get(getContext());
63   if (keyword == "shape")
64     return ShapeType::get(getContext());
65   if (keyword == "size")
66     return SizeType::get(getContext());
67   if (keyword == "value_shape")
68     return ValueShapeType::get(getContext());
69   if (keyword == "witness")
70     return WitnessType::get(getContext());
71 
72   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
73   return Type();
74 }
75 
76 /// Print a type registered to this dialect.
77 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
78   switch (type.getKind()) {
79   case ShapeTypes::Component:
80     os << "component";
81     return;
82   case ShapeTypes::Element:
83     os << "element";
84     return;
85   case ShapeTypes::Size:
86     os << "size";
87     return;
88   case ShapeTypes::Shape:
89     os << "shape";
90     return;
91   case ShapeTypes::ValueShape:
92     os << "value_shape";
93     return;
94   case ShapeTypes::Witness:
95     os << "witness";
96     return;
97   default:
98     llvm_unreachable("unexpected 'shape' type kind");
99   }
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // AnyOp
104 //===----------------------------------------------------------------------===//
105 
106 // TODO: Canonicalization should be implemented for shapes that can be
107 // determined through mixtures of the known dimensions of the inputs.
108 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
109   // Only the last operand is checked because AnyOp is commutative.
110   if (operands.back())
111     return operands.back();
112 
113   return nullptr;
114 }
115 
116 //===----------------------------------------------------------------------===//
117 // AssumingOp
118 //===----------------------------------------------------------------------===//
119 
// Custom parser for AssumingOp: `shape.assuming <witness> [-> (types)] region
// [attr-dict]`.
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // The witness operand always has `!shape.witness` type, so no explicit type
  // appears in the assembly.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
146 
147 static void print(OpAsmPrinter &p, AssumingOp op) {
148   bool yieldsResults = !op.results().empty();
149 
150   p << AssumingOp::getOperationName() << " " << op.witness();
151   if (yieldsResults) {
152     p << " -> (" << op.getResultTypes() << ")";
153   }
154   p.printRegion(op.doRegion(),
155                 /*printEntryBlockArgs=*/false,
156                 /*printBlockTerminators=*/yieldsResults);
157   p.printOptionalAttrDict(op.getAttrs());
158 }
159 
160 namespace {
161 // Removes AssumingOp with a passing witness and inlines the region.
162 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
163   using OpRewritePattern<AssumingOp>::OpRewritePattern;
164 
165   LogicalResult matchAndRewrite(AssumingOp op,
166                                 PatternRewriter &rewriter) const override {
167     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
168     if (!witness || !witness.passingAttr())
169       return failure();
170 
171     AssumingOp::inlineRegionIntoParent(op, rewriter);
172     return success();
173   }
174 };
175 } // namespace
176 
void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                             MLIRContext *context) {
  // If taking a passing witness, inline the region into the parent.
  patterns.insert<AssumingWithTrue>(context);
}
182 
// Splices the body of `op` into the block containing `op`, replacing the op's
// results with the operands of its terminator. The ordering of the rewriter
// calls below is load-bearing; do not reorder.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  // Split the surrounding block at the op, so the region's contents can be
  // spliced in between the two halves.
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp. The body's trailing op is its
  // terminator; its operands become the replacement values for the op's
  // results.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
202 
203 //===----------------------------------------------------------------------===//
204 // AssumingAllOp
205 //===----------------------------------------------------------------------===//
// NOTE(review): this folder mutates the op (erases operands) as a side
// effect, which is unusual for a fold hook — verify this interacts correctly
// with the folding driver.
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(true, getContext());
}
226 
227 static LogicalResult verify(AssumingAllOp op) {
228   // Ensure that AssumingAllOp contains at least one operand
229   if (op.getNumOperands() == 0)
230     return op.emitOpError("no operands specified");
231 
232   return success();
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // BroadcastOp
237 //===----------------------------------------------------------------------===//
238 
239 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
240   if (!operands[0] || !operands[1])
241     return nullptr;
242   auto lhsShape = llvm::to_vector<6>(
243       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
244   auto rhsShape = llvm::to_vector<6>(
245       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
246   SmallVector<int64_t, 6> resultShape;
247   // If the shapes are not compatible, we can't fold it.
248   // TODO: Fold to an "error".
249   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
250     return nullptr;
251   Builder builder(getContext());
252   return builder.getIndexTensorAttr(resultShape);
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // ConcatOp
257 //===----------------------------------------------------------------------===//
258 
259 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
260   if (!operands[0] || !operands[1])
261     return nullptr;
262   auto lhsShape = llvm::to_vector<6>(
263       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
264   auto rhsShape = llvm::to_vector<6>(
265       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
266   SmallVector<int64_t, 6> resultShape;
267   resultShape.append(lhsShape.begin(), lhsShape.end());
268   resultShape.append(rhsShape.begin(), rhsShape.end());
269   Builder builder(getContext());
270   return builder.getIndexTensorAttr(resultShape);
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // ConstShapeOp
275 //===----------------------------------------------------------------------===//
276 
277 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
278   p << "shape.const_shape ";
279   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
280   p << "[";
281   interleaveComma(op.shape().getValues<int64_t>(), p,
282                   [&](int64_t i) { p << i; });
283   p << "]";
284 }
285 
// Custom parser for ConstShapeOp; mirrors print(OpAsmPrinter&, ConstShapeOp&).
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  // The "dummy" list only exists because parseAttribute requires a target
  // attribute list; the parsed value is taken from `extentsRaw` instead.
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Each element must be an integer attribute; collect the raw extents.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  // Store the extents canonically as an index tensor attribute named "shape".
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));

  // The single result is always a !shape.shape.
  result.types.push_back(ShapeType::get(builder.getContext()));
  return success();
}
313 
// A constant shape trivially folds to its own shape attribute.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
315 
316 //===----------------------------------------------------------------------===//
317 // CstrBroadcastableOp
318 //===----------------------------------------------------------------------===//
319 
void CstrBroadcastableOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // If inputs are equal, return passing witness (pattern generated from
  // ShapeCanonicalization.td).
  patterns.insert<CstrBroadcastableEqOps>(context);
}
325 
326 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
327   if (!operands[0] || !operands[1])
328     return nullptr;
329   auto lhsShape = llvm::to_vector<6>(
330       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
331   auto rhsShape = llvm::to_vector<6>(
332       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
333   SmallVector<int64_t, 6> resultShape;
334   if (OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
335     return BoolAttr::get(true, getContext());
336 
337   // Because a failing witness result here represents an eventual assertion
338   // failure, we do not replace it with a constant witness.
339   return nullptr;
340 }
341 
342 //===----------------------------------------------------------------------===//
343 // CstrEqOp
344 //===----------------------------------------------------------------------===//
345 
void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness (pattern generated from
  // ShapeCanonicalization.td).
  patterns.insert<CstrEqEqOps>(context);
}
351 
352 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
353   if (llvm::all_of(operands,
354                    [&](Attribute a) { return a && a == operands[0]; }))
355     return BoolAttr::get(true, getContext());
356 
357   // Because a failing witness result here represents an eventual assertion
358   // failure, we do not try to replace it with a constant witness. Similarly, we
359   // cannot if there are any non-const inputs.
360   return nullptr;
361 }
362 
363 //===----------------------------------------------------------------------===//
364 // ConstSizeOp
365 //===----------------------------------------------------------------------===//
366 
// A constant size trivially folds to its own value attribute.
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
368 
369 void ConstSizeOp::getAsmResultNames(
370     llvm::function_ref<void(Value, StringRef)> setNameFn) {
371   SmallString<4> buffer;
372   llvm::raw_svector_ostream os(buffer);
373   os << "c" << value();
374   setNameFn(getResult(), os.str());
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // ConstWitnessOp
379 //===----------------------------------------------------------------------===//
380 
// A constant witness trivially folds to its own passing attribute.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
382 
383 //===----------------------------------------------------------------------===//
384 // IndexToSizeOp
385 //===----------------------------------------------------------------------===//
386 
387 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
388   // Constant values of both types, `shape.size` and `index`, are represented as
389   // `IntegerAttr`s which makes constant folding simple.
390   if (Attribute arg = operands[0])
391     return arg;
392   return {};
393 }
394 
395 //===----------------------------------------------------------------------===//
396 // FromExtentsOp
397 //===----------------------------------------------------------------------===//
398 
399 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
400   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
401     return nullptr;
402   SmallVector<int64_t, 6> extents;
403   for (auto attr : operands)
404     extents.push_back(attr.cast<IntegerAttr>().getInt());
405   Builder builder(getContext());
406   return builder.getIndexTensorAttr(extents);
407 }
408 
409 //===----------------------------------------------------------------------===//
410 // GetExtentOp
411 //===----------------------------------------------------------------------===//
412 
413 Optional<int64_t> GetExtentOp::getConstantDim() {
414   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) {
415     return constSizeOp.value().getLimitedValue();
416   }
417   return llvm::None;
418 }
419 
420 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
421   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
422   if (!elements)
423     return nullptr;
424   Optional<int64_t> dim = getConstantDim();
425   if (!dim.hasValue())
426     return nullptr;
427   if (dim.getValue() >= elements.getNumElements())
428     return nullptr;
429   return elements.getValue({(uint64_t)dim.getValue()});
430 }
431 
432 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
433                         int64_t dim) {
434   auto loc = result.location;
435   auto dimAttr = builder.getIndexAttr(dim);
436   Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr);
437   build(builder, result, shape, dimValue);
438 }
439 
440 //===----------------------------------------------------------------------===//
441 // NumElementsOp
442 //===----------------------------------------------------------------------===//
443 
444 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
445 
446   // Fold only when argument constant.
447   Attribute shape = operands[0];
448   if (!shape)
449     return {};
450 
451   APInt product(64, 1);
452   for (auto value : shape.cast<DenseIntElementsAttr>())
453     product *= value;
454   Builder builder(getContext());
455   return builder.getIndexAttr(product.getLimitedValue());
456 }
457 
458 //===----------------------------------------------------------------------===//
459 // ShapeOfOp
460 //===----------------------------------------------------------------------===//
461 
462 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
463   auto type = getOperand().getType().dyn_cast<ShapedType>();
464   if (!type || !type.hasStaticShape())
465     return nullptr;
466   Builder builder(getContext());
467   return builder.getIndexTensorAttr(type.getShape());
468 }
469 
470 //===----------------------------------------------------------------------===//
471 // SizeToIndexOp
472 //===----------------------------------------------------------------------===//
473 
474 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
475   // Constant values of both types, `shape.size` and `index`, are represented as
476   // `IntegerAttr`s which makes constant folding simple.
477   if (Attribute arg = operands[0])
478     return arg;
479   return {};
480 }
481 
482 //===----------------------------------------------------------------------===//
483 // YieldOp
484 //===----------------------------------------------------------------------===//
485 
486 static LogicalResult verify(YieldOp op) {
487   auto *parentOp = op.getParentOp();
488   auto results = parentOp->getResults();
489   auto operands = op.getOperands();
490 
491   if (parentOp->getNumResults() != op.getNumOperands())
492     return op.emitOpError() << "number of operands does not match number of "
493                                "results of its parent";
494   for (auto e : llvm::zip(results, operands))
495     if (std::get<0>(e).getType() != std::get<1>(e).getType())
496       return op.emitOpError()
497              << "types mismatch between yield op and its parent";
498 
499   return success();
500 }
501 
502 //===----------------------------------------------------------------------===//
503 // SplitAtOp
504 //===----------------------------------------------------------------------===//
505 
506 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
507                               SmallVectorImpl<OpFoldResult> &results) {
508   if (!operands[0] || !operands[1])
509     return failure();
510   auto shapeVec = llvm::to_vector<6>(
511       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
512   auto shape = llvm::makeArrayRef(shapeVec);
513   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
514   // Verify that the split point is in the correct range.
515   // TODO: Constant fold to an "error".
516   int64_t rank = shape.size();
517   if (!(-rank <= splitPoint && splitPoint <= rank))
518     return failure();
519   if (splitPoint < 0)
520     splitPoint += shape.size();
521   Builder builder(operands[0].getContext());
522   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
523   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
524   return success();
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // ToExtentTensorOp
529 //===----------------------------------------------------------------------===//
530 
531 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
532   if (!operands[0])
533     return nullptr;
534   Builder builder(getContext());
535   auto shape = llvm::to_vector<6>(
536       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
537   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
538                                     builder.getIndexType());
539   return DenseIntElementsAttr::get(type, shape);
540 }
541 
542 //===----------------------------------------------------------------------===//
543 // ReduceOp
544 //===----------------------------------------------------------------------===//
545 
// Builder for ReduceOp: creates the body region with its entry block. The
// block's arguments are, in order: the current dimension index, the extent at
// that dimension, and one accumulator per initial value.
void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  result.addOperands(shape);
  result.addOperands(initVals);

  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  bodyBlock.addArgument(builder.getIndexType());
  bodyBlock.addArgument(SizeType::get(builder.getContext()));

  // The op returns one value per initial value, of the same type.
  for (Type initValType : initVals.getTypes()) {
    bodyBlock.addArgument(initValType);
    result.addTypes(initValType);
  }
}
562 
563 static LogicalResult verify(ReduceOp op) {
564   // Verify block arg types.
565   Block &block = op.region().front();
566 
567   auto blockArgsCount = op.initVals().size() + 2;
568   if (block.getNumArguments() != blockArgsCount)
569     return op.emitOpError() << "ReduceOp body is expected to have "
570                             << blockArgsCount << " arguments";
571 
572   if (block.getArgument(0).getType() != IndexType::get(op.getContext()))
573     return op.emitOpError(
574         "argument 0 of ReduceOp body is expected to be of IndexType");
575 
576   if (block.getArgument(1).getType() != SizeType::get(op.getContext()))
577     return op.emitOpError(
578         "argument 1 of ReduceOp body is expected to be of SizeType");
579 
580   for (auto type : llvm::enumerate(op.initVals()))
581     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
582       return op.emitOpError()
583              << "type mismatch between argument " << type.index() + 2
584              << " of ReduceOp body and initial value " << type.index();
585   return success();
586 }
587 
// Custom parser for ReduceOp: `shape.reduce(shape, inits...) -> (types) region
// [attr-dict]`.
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
  auto *ctx = parser.getBuilder().getContext();
  // Parse operands.
  SmallVector<OpAsmParser::OperandType, 3> operands;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands. The first operand is always the shape; the remaining
  // ones are the initial accumulator values, whose types are taken from the
  // result type list.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), ShapeType::get(ctx),
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
616 
617 static void print(OpAsmPrinter &p, ReduceOp op) {
618   p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
619     << ") ";
620   p.printOptionalArrowTypeList(op.getResultTypes());
621   p.printRegion(op.region());
622   p.printOptionalAttrDict(op.getAttrs());
623 }
624 
625 namespace mlir {
626 namespace shape {
627 
628 #define GET_OP_CLASSES
629 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
630 
631 } // namespace shape
632 } // namespace mlir
633