1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/Traits.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/StandardTypes.h"
16 #include "llvm/Support/raw_ostream.h"
17 
18 using namespace mlir;
19 using namespace mlir::shape;
20 
21 ShapeDialect::ShapeDialect(MLIRContext *context)
22     : Dialect(getDialectNamespace(), context) {
23   addOperations<
24 #define GET_OP_LIST
25 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
26       >();
27   addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
28            WitnessType>();
29   // Allow unknown operations during prototyping and testing. As the dialect is
30   // still evolving it makes it simple to start with an unregistered ops and
31   // try different variants before actually defining the op.
32   allowUnknownOperations();
33 }
34 
35 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
36                                              Attribute value, Type type,
37                                              Location loc) {
38   if (auto shapeType = type.dyn_cast<ShapeType>()) {
39     return builder.create<ConstShapeOp>(loc, type,
40                                         value.cast<DenseIntElementsAttr>());
41   }
42   if (auto sizeType = type.dyn_cast<SizeType>()) {
43     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
44   }
45   return nullptr;
46 }
47 
48 /// Parse a type registered to this dialect.
49 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
50   StringRef keyword;
51   if (parser.parseKeyword(&keyword))
52     return Type();
53 
54   if (keyword == "component")
55     return ComponentType::get(getContext());
56   if (keyword == "element")
57     return ElementType::get(getContext());
58   if (keyword == "shape")
59     return ShapeType::get(getContext());
60   if (keyword == "size")
61     return SizeType::get(getContext());
62   if (keyword == "value_shape")
63     return ValueShapeType::get(getContext());
64   if (keyword == "witness")
65     return WitnessType::get(getContext());
66 
67   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
68   return Type();
69 }
70 
71 /// Print a type registered to this dialect.
72 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
73   switch (type.getKind()) {
74   case ShapeTypes::Component:
75     os << "component";
76     return;
77   case ShapeTypes::Element:
78     os << "element";
79     return;
80   case ShapeTypes::Size:
81     os << "size";
82     return;
83   case ShapeTypes::Shape:
84     os << "shape";
85     return;
86   case ShapeTypes::ValueShape:
87     os << "value_shape";
88     return;
89   case ShapeTypes::Witness:
90     os << "witness";
91     return;
92   default:
93     llvm_unreachable("unexpected 'shape' type kind");
94   }
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // AnyOp
99 //===----------------------------------------------------------------------===//
100 
101 LogicalResult
102 AnyOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
103                         ValueRange operands, DictionaryAttr attributes,
104                         RegionRange regions,
105                         SmallVectorImpl<Type> &inferredReturnTypes) {
106   inferredReturnTypes.push_back(ShapeType::get(context));
107   return success();
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // BroadcastOp
112 //===----------------------------------------------------------------------===//
113 
114 LogicalResult
115 BroadcastOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
116                               ValueRange operands, DictionaryAttr attributes,
117                               RegionRange regions,
118                               SmallVectorImpl<Type> &inferredReturnTypes) {
119   inferredReturnTypes.push_back(ShapeType::get(context));
120   return success();
121 }
122 
123 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
124   if (!operands[0] || !operands[1])
125     return nullptr;
126   auto lhsShape = llvm::to_vector<6>(
127       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
128   auto rhsShape = llvm::to_vector<6>(
129       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
130   SmallVector<int64_t, 6> resultShape;
131   // If the shapes are not compatible, we can't fold it.
132   // TODO: Fold to an "error".
133   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
134     return nullptr;
135   Builder builder(getContext());
136   return builder.getI64TensorAttr(resultShape);
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // ConstShapeOp
141 //===----------------------------------------------------------------------===//
142 
143 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
144   p << "shape.const_shape ";
145   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
146   p << "[";
147   interleaveComma(op.shape().getValues<int64_t>(), p,
148                   [&](int64_t i) { p << i; });
149   p << "]";
150 }
151 
152 static ParseResult parseConstShapeOp(OpAsmParser &parser,
153                                      OperationState &result) {
154   if (parser.parseOptionalAttrDict(result.attributes))
155     return failure();
156   // We piggy-back on ArrayAttr parsing, though we don't internally store the
157   // shape as an ArrayAttr.
158   // TODO: Implement custom parser and maybe make syntax a bit more concise.
159   Attribute extentsRaw;
160   NamedAttrList dummy;
161   if (parser.parseAttribute(extentsRaw, "dummy", dummy))
162     return failure();
163   auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
164   if (!extentsArray)
165     return failure();
166   SmallVector<int64_t, 6> ints;
167   for (Attribute extent : extentsArray) {
168     IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
169     if (!attr)
170       return failure();
171     ints.push_back(attr.getInt());
172   }
173   Builder &builder = parser.getBuilder();
174   result.addAttribute("shape", builder.getI64TensorAttr(ints));
175 
176   result.types.push_back(ShapeType::get(builder.getContext()));
177   return success();
178 }
179 
180 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }
181 
182 LogicalResult
183 ConstShapeOp::inferReturnTypes(MLIRContext *context,
184                                Optional<Location> location, ValueRange operands,
185                                DictionaryAttr attributes, RegionRange regions,
186                                SmallVectorImpl<Type> &inferredReturnTypes) {
187   inferredReturnTypes.push_back(ShapeType::get(context));
188   return success();
189 }
190 
191 //===----------------------------------------------------------------------===//
192 // ConstSizeOp
193 //===----------------------------------------------------------------------===//
194 
195 LogicalResult
196 ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
197                               ValueRange operands, DictionaryAttr attributes,
198                               RegionRange regions,
199                               SmallVectorImpl<Type> &inferredReturnTypes) {
200   inferredReturnTypes.push_back(SizeType::get(context));
201   return success();
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // ShapeOfOp
206 //===----------------------------------------------------------------------===//
207 
208 LogicalResult
209 ShapeOfOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
210                             ValueRange operands, DictionaryAttr attributes,
211                             RegionRange regions,
212                             SmallVectorImpl<Type> &inferredReturnTypes) {
213   inferredReturnTypes.push_back(ShapeType::get(context));
214   return success();
215 }
216 
217 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
218   auto type = getOperand().getType().dyn_cast<ShapedType>();
219   if (!type || !type.hasStaticShape())
220     return nullptr;
221   Builder builder(getContext());
222   return builder.getI64TensorAttr(type.getShape());
223 }
224 
225 //===----------------------------------------------------------------------===//
226 // SplitAtOp
227 //===----------------------------------------------------------------------===//
228 
229 LogicalResult
230 SplitAtOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
231                             ValueRange operands, DictionaryAttr attributes,
232                             RegionRange regions,
233                             SmallVectorImpl<Type> &inferredReturnTypes) {
234   auto shapeType = ShapeType::get(context);
235   inferredReturnTypes.push_back(shapeType);
236   inferredReturnTypes.push_back(shapeType);
237   return success();
238 }
239 
240 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
241                               SmallVectorImpl<OpFoldResult> &results) {
242   if (!operands[0] || !operands[1])
243     return failure();
244   auto shapeVec = llvm::to_vector<6>(
245       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
246   auto shape = llvm::makeArrayRef(shapeVec);
247   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
248   // Verify that the split point is in the correct range.
249   // TODO: Constant fold to an "error".
250   int64_t rank = shape.size();
251   if (!(-rank <= splitPoint && splitPoint <= rank))
252     return failure();
253   if (splitPoint < 0)
254     splitPoint += shape.size();
255   Builder builder(operands[0].getContext());
256   results.push_back(builder.getI64TensorAttr(shape.take_front(splitPoint)));
257   results.push_back(builder.getI64TensorAttr(shape.drop_front(splitPoint)));
258   return success();
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // ConcatOp
263 //===----------------------------------------------------------------------===//
264 
265 LogicalResult
266 ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
267                            ValueRange operands, DictionaryAttr attributes,
268                            RegionRange regions,
269                            SmallVectorImpl<Type> &inferredReturnTypes) {
270   auto shapeType = ShapeType::get(context);
271   inferredReturnTypes.push_back(shapeType);
272   return success();
273 }
274 
275 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
276   if (!operands[0] || !operands[1])
277     return nullptr;
278   auto lhsShape = llvm::to_vector<6>(
279       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
280   auto rhsShape = llvm::to_vector<6>(
281       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
282   SmallVector<int64_t, 6> resultShape;
283   resultShape.append(lhsShape.begin(), lhsShape.end());
284   resultShape.append(rhsShape.begin(), rhsShape.end());
285   Builder builder(getContext());
286   return builder.getI64TensorAttr(resultShape);
287 }
288 
289 //===----------------------------------------------------------------------===//
290 // ToExtentTensorOp
291 //===----------------------------------------------------------------------===//
292 
293 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
294   if (!operands[0])
295     return nullptr;
296   Builder builder(getContext());
297   auto shape = llvm::to_vector<6>(
298       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
299   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
300                                     builder.getIndexType());
301   return DenseIntElementsAttr::get(type, shape);
302 }
303 
// Definitions of the ODS-generated op classes (the GET_OP_CLASSES section of
// ShapeOps.cpp.inc), emitted inside the mlir::shape namespace.
// NOTE(review): presumably the generated code relies on being included within
// this namespace; confirm against the generated .inc before moving it.
namespace mlir {
namespace shape {

#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"

} // namespace shape
} // namespace mlir
312