1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/StandardTypes.h"
14 #include "llvm/Support/raw_ostream.h"
15 
16 using namespace mlir;
17 using namespace mlir::shape;
18 
/// Construct the `shape` dialect: register its operations (the class list is
/// spliced in from ShapeOps.cpp.inc), register its five types, and allow
/// unknown operations.
ShapeDialect::ShapeDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // The op class list comes from expanding ShapeOps.cpp.inc with GET_OP_LIST
  // defined; the #include must stay inside the template argument list.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
31 
32 /// Parse a type registered to this dialect.
33 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
34   StringRef keyword;
35   if (parser.parseKeyword(&keyword))
36     return Type();
37 
38   if (keyword == "component")
39     return ComponentType::get(getContext());
40   if (keyword == "element")
41     return ElementType::get(getContext());
42   if (keyword == "shape")
43     return ShapeType::get(getContext());
44   if (keyword == "size")
45     return SizeType::get(getContext());
46   if (keyword == "value_shape")
47     return ValueShapeType::get(getContext());
48 
49   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
50   return Type();
51 }
52 
53 /// Print a type registered to this dialect.
54 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
55   switch (type.getKind()) {
56   case ShapeTypes::Component:
57     os << "component";
58     return;
59   case ShapeTypes::Element:
60     os << "element";
61     return;
62   case ShapeTypes::Size:
63     os << "size";
64     return;
65   case ShapeTypes::Shape:
66     os << "shape";
67     return;
68   case ShapeTypes::ValueShape:
69     os << "value_shape";
70     return;
71   default:
72     llvm_unreachable("unexpected 'shape' type kind");
73   }
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // Constant*Op
78 //===----------------------------------------------------------------------===//
79 
80 static void print(OpAsmPrinter &p, ConstantOp &op) {
81   p << "shape.constant ";
82   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
83 
84   if (op.getAttrs().size() > 1)
85     p << ' ';
86   p.printAttributeWithoutType(op.value());
87   p << " : " << op.getType();
88 }
89 
90 static ParseResult parseConstantOp(OpAsmParser &parser,
91                                    OperationState &result) {
92   Attribute valueAttr;
93   if (parser.parseOptionalAttrDict(result.attributes))
94     return failure();
95   Type i64Type = parser.getBuilder().getIntegerType(64);
96   if (parser.parseAttribute(valueAttr, i64Type, "value", result.attributes))
97     return failure();
98 
99   Type type;
100   if (parser.parseColonType(type))
101     return failure();
102 
103   // Add the attribute type to the list.
104   return parser.addTypeToList(type, result.types);
105 }
106 
107 static LogicalResult verify(ConstantOp &op) { return success(); }
108 
109 //===----------------------------------------------------------------------===//
110 // SplitAtOp
111 //===----------------------------------------------------------------------===//
112 
113 LogicalResult SplitAtOp::inferReturnTypes(
114     MLIRContext *context, Optional<Location> location, ValueRange operands,
115     ArrayRef<NamedAttribute> attributes, RegionRange regions,
116     SmallVectorImpl<Type> &inferredReturnTypes) {
117   auto shapeType = ShapeType::get(context);
118   inferredReturnTypes.push_back(shapeType);
119   inferredReturnTypes.push_back(shapeType);
120   return success();
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // ConcatOp
125 //===----------------------------------------------------------------------===//
126 
127 LogicalResult ConcatOp::inferReturnTypes(
128     MLIRContext *context, Optional<Location> location, ValueRange operands,
129     ArrayRef<NamedAttribute> attributes, RegionRange regions,
130     SmallVectorImpl<Type> &inferredReturnTypes) {
131   auto shapeType = ShapeType::get(context);
132   inferredReturnTypes.push_back(shapeType);
133   return success();
134 }
135 
// Op class definitions for the dialect, expanded from ShapeOps.cpp.inc with
// GET_OP_CLASSES defined. The include sits inside `mlir::shape` so that the
// expanded definitions are placed in the dialect's namespace.
namespace mlir {
namespace shape {

#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"

} // namespace shape
} // namespace mlir
144