//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::shape;

ShapeDialect::ShapeDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining an op.
  allowUnknownOperations();
}

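/// Materializes a folded constant: builds a ConstShapeOp for ShapeType results
/// and a ConstSizeOp for SizeType results; any other type is not handled and
/// returns nullptr.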
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  if (auto shapeType = type.dyn_cast<ShapeType>()) {
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  }
  if (auto sizeType = type.dyn_cast<SizeType>()) {
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  }
  return nullptr;
}

/// Parse a type registered to this dialect.
Type ShapeDialect::parseType(DialectAsmParser &parser) const {
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();

  if (keyword == "component")
    return ComponentType::get(getContext());
  if (keyword == "element")
    return ElementType::get(getContext());
  if (keyword == "shape")
    return ShapeType::get(getContext());
  if (keyword == "size")
    return SizeType::get(getContext());
  if (keyword == "value_shape")
    return ValueShapeType::get(getContext());

  parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
  return Type();
}

/// Print a type registered to this dialect.
void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
  switch (type.getKind()) {
  case ShapeTypes::Component:
    os << "component";
    return;
  case ShapeTypes::Element:
    os << "element";
    return;
  case ShapeTypes::Size:
    os << "size";
    return;
  case ShapeTypes::Shape:
    os << "shape";
    return;
  case ShapeTypes::ValueShape:
    os << "value_shape";
    return;
  default:
    llvm_unreachable("unexpected 'shape' type kind");
  }
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

LogicalResult
BroadcastOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
                              ValueRange operands, DictionaryAttr attributes,
                              RegionRange regions,
                              SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.push_back(ShapeType::get(context));
  return success();
}

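/// Folds broadcasting over two constant shapes, e.g. [2, 1] and [3, 1, 1]
/// broadcast to [3, 2, 1]; non-constant operands and incompatible shapes are
/// left unfolded.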
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;
  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;
  Builder builder(getContext());
  return builder.getI64TensorAttr(resultShape);
}

//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//

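/// Custom assembly form printed/parsed below, e.g.:
///   shape.const_shape [1, 2, 3]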
static void print(OpAsmPrinter &p, ConstShapeOp &op) {
  p << "shape.const_shape ";
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(op.shape().getValues<int64_t>(), p,
                  [&](int64_t i) { p << i; });
  p << "]";
}

static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
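  // E.g. `[1, 2, 3]` is parsed as an ArrayAttr of IntegerAttrs and then
  // repacked into an i64 tensor attribute named "shape".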
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getI64TensorAttr(ints));

  result.types.push_back(ShapeType::get(builder.getContext()));
  return success();
}

OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }

LogicalResult
ConstShapeOp::inferReturnTypes(MLIRContext *context,
                               Optional<Location> location, ValueRange operands,
                               DictionaryAttr attributes, RegionRange regions,
                               SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.push_back(ShapeType::get(context));
  return success();
}

//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//

LogicalResult
ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
                              ValueRange operands, DictionaryAttr attributes,
                              RegionRange regions,
                              SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.push_back(SizeType::get(context));
  return success();
}

//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

LogicalResult
ShapeOfOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
                            ValueRange operands, DictionaryAttr attributes,
                            RegionRange regions,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.push_back(ShapeType::get(context));
  return success();
}

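/// Folds to a constant shape when the operand is a ranked, fully static
/// ShapedType, e.g. an operand of type tensor<2x3xf32> folds to the shape
/// [2, 3].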
OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
  auto type = getOperand().getType().dyn_cast<ShapedType>();
  if (!type || !type.hasStaticShape())
    return nullptr;
  Builder builder(getContext());
  return builder.getI64TensorAttr(type.getShape());
}

//===----------------------------------------------------------------------===//
// SplitAtOp
//===----------------------------------------------------------------------===//

LogicalResult
SplitAtOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
                            ValueRange operands, DictionaryAttr attributes,
                            RegionRange regions,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto shapeType = ShapeType::get(context);
  inferredReturnTypes.push_back(shapeType);
  inferredReturnTypes.push_back(shapeType);
  return success();
}

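/// Folds splitting of a constant shape at a constant index. A negative split
/// point counts from the back, e.g. splitting [4, 5, 6] at -1 yields the pair
/// [4, 5] and [6].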
LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<OpFoldResult> &results) {
  if (!operands[0] || !operands[1])
    return failure();
  auto shapeVec = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto shape = llvm::makeArrayRef(shapeVec);
  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
  // Verify that the split point is in the correct range.
  // TODO: Constant fold to an "error".
  int64_t rank = shape.size();
  if (!(-rank <= splitPoint && splitPoint <= rank))
    return failure();
  if (splitPoint < 0)
    splitPoint += shape.size();
  Builder builder(operands[0].getContext());
  results.push_back(builder.getI64TensorAttr(shape.take_front(splitPoint)));
  results.push_back(builder.getI64TensorAttr(shape.drop_front(splitPoint)));
  return success();
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

LogicalResult
ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
                           ValueRange operands, DictionaryAttr attributes,
                           RegionRange regions,
                           SmallVectorImpl<Type> &inferredReturnTypes) {
  auto shapeType = ShapeType::get(context);
  inferredReturnTypes.push_back(shapeType);
  return success();
}

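/// Folds concatenation of two constant shapes by appending their extents,
/// e.g. [1, 2] and [3] concatenate to [1, 2, 3].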
OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;
  resultShape.append(lhsShape.begin(), lhsShape.end());
  resultShape.append(rhsShape.begin(), rhsShape.end());
  Builder builder(getContext());
  return builder.getI64TensorAttr(resultShape);
}

//===----------------------------------------------------------------------===//
// ToExtentTensorOp
//===----------------------------------------------------------------------===//

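/// Folds a constant shape to a 1-D tensor of index extents, e.g. the shape
/// [2, 3] folds to a dense constant of type tensor<2xindex>.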
OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[0])
    return nullptr;
  Builder builder(getContext());
  auto shape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
                                    builder.getIndexType());
  return DenseIntElementsAttr::get(type, shape);
}

namespace mlir {
namespace shape {

#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"

} // namespace shape
} // namespace mlir