1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/Traits.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/StandardTypes.h"
16 #include "llvm/Support/raw_ostream.h"
17 
18 using namespace mlir;
19 using namespace mlir::shape;
20 
/// Constructs the Shape dialect: registers its operations and its types
/// (component, element, shape, size, value_shape, witness) with `context`.
ShapeDialect::ShapeDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // The operation list is spliced in from ShapeOps.cpp.inc, expanded under
  // GET_OP_LIST.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
           WitnessType>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
34 
35 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
36                                              Attribute value, Type type,
37                                              Location loc) {
38   if (auto shapeType = type.dyn_cast<ShapeType>()) {
39     return builder.create<ConstShapeOp>(loc, type,
40                                         value.cast<DenseIntElementsAttr>());
41   }
42   if (auto sizeType = type.dyn_cast<SizeType>()) {
43     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
44   }
45   return nullptr;
46 }
47 
48 /// Parse a type registered to this dialect.
49 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
50   StringRef keyword;
51   if (parser.parseKeyword(&keyword))
52     return Type();
53 
54   if (keyword == "component")
55     return ComponentType::get(getContext());
56   if (keyword == "element")
57     return ElementType::get(getContext());
58   if (keyword == "shape")
59     return ShapeType::get(getContext());
60   if (keyword == "size")
61     return SizeType::get(getContext());
62   if (keyword == "value_shape")
63     return ValueShapeType::get(getContext());
64   if (keyword == "witness")
65     return WitnessType::get(getContext());
66 
67   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
68   return Type();
69 }
70 
71 /// Print a type registered to this dialect.
72 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
73   switch (type.getKind()) {
74   case ShapeTypes::Component:
75     os << "component";
76     return;
77   case ShapeTypes::Element:
78     os << "element";
79     return;
80   case ShapeTypes::Size:
81     os << "size";
82     return;
83   case ShapeTypes::Shape:
84     os << "shape";
85     return;
86   case ShapeTypes::ValueShape:
87     os << "value_shape";
88     return;
89   case ShapeTypes::Witness:
90     os << "witness";
91     return;
92   default:
93     llvm_unreachable("unexpected 'shape' type kind");
94   }
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // AnyOp
99 //===----------------------------------------------------------------------===//
100 
101 LogicalResult
102 AnyOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
103                         ValueRange operands, DictionaryAttr attributes,
104                         RegionRange regions,
105                         SmallVectorImpl<Type> &inferredReturnTypes) {
106   inferredReturnTypes.push_back(ShapeType::get(context));
107   return success();
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // AssumingOp
112 //===----------------------------------------------------------------------===//
113 
114 static ParseResult parseAssumingOp(OpAsmParser &parser,
115                                    OperationState &result) {
116   result.regions.reserve(1);
117   Region *doRegion = result.addRegion();
118 
119   auto &builder = parser.getBuilder();
120   OpAsmParser::OperandType cond;
121   if (parser.parseOperand(cond) ||
122       parser.resolveOperand(cond, builder.getType<WitnessType>(),
123                             result.operands))
124     return failure();
125 
126   // Parse optional results type list.
127   if (parser.parseOptionalArrowTypeList(result.types))
128     return failure();
129 
130   // Parse the region and add a terminator if elided.
131   if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
132     return failure();
133   AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
134 
135   // Parse the optional attribute list.
136   if (parser.parseOptionalAttrDict(result.attributes))
137     return failure();
138   return success();
139 }
140 
141 static void print(OpAsmPrinter &p, AssumingOp op) {
142   bool yieldsResults = !op.results().empty();
143 
144   p << AssumingOp::getOperationName() << " " << op.witness();
145   if (yieldsResults) {
146     p << " -> (" << op.getResultTypes() << ")";
147   }
148   p.printRegion(op.doRegion(),
149                 /*printEntryBlockArgs=*/false,
150                 /*printBlockTerminators=*/yieldsResults);
151   p.printOptionalAttrDict(op.getAttrs());
152 }
153 
154 //===----------------------------------------------------------------------===//
155 // BroadcastOp
156 //===----------------------------------------------------------------------===//
157 
158 LogicalResult
159 BroadcastOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
160                               ValueRange operands, DictionaryAttr attributes,
161                               RegionRange regions,
162                               SmallVectorImpl<Type> &inferredReturnTypes) {
163   inferredReturnTypes.push_back(ShapeType::get(context));
164   return success();
165 }
166 
167 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
168   if (!operands[0] || !operands[1])
169     return nullptr;
170   auto lhsShape = llvm::to_vector<6>(
171       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
172   auto rhsShape = llvm::to_vector<6>(
173       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
174   SmallVector<int64_t, 6> resultShape;
175   // If the shapes are not compatible, we can't fold it.
176   // TODO: Fold to an "error".
177   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
178     return nullptr;
179   Builder builder(getContext());
180   return builder.getI64TensorAttr(resultShape);
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // ConstShapeOp
185 //===----------------------------------------------------------------------===//
186 
187 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
188   p << "shape.const_shape ";
189   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
190   p << "[";
191   interleaveComma(op.shape().getValues<int64_t>(), p,
192                   [&](int64_t i) { p << i; });
193   p << "]";
194 }
195 
196 static ParseResult parseConstShapeOp(OpAsmParser &parser,
197                                      OperationState &result) {
198   if (parser.parseOptionalAttrDict(result.attributes))
199     return failure();
200   // We piggy-back on ArrayAttr parsing, though we don't internally store the
201   // shape as an ArrayAttr.
202   // TODO: Implement custom parser and maybe make syntax a bit more concise.
203   Attribute extentsRaw;
204   NamedAttrList dummy;
205   if (parser.parseAttribute(extentsRaw, "dummy", dummy))
206     return failure();
207   auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
208   if (!extentsArray)
209     return failure();
210   SmallVector<int64_t, 6> ints;
211   for (Attribute extent : extentsArray) {
212     IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
213     if (!attr)
214       return failure();
215     ints.push_back(attr.getInt());
216   }
217   Builder &builder = parser.getBuilder();
218   result.addAttribute("shape", builder.getI64TensorAttr(ints));
219 
220   result.types.push_back(ShapeType::get(builder.getContext()));
221   return success();
222 }
223 
224 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }
225 
226 LogicalResult
227 ConstShapeOp::inferReturnTypes(MLIRContext *context,
228                                Optional<Location> location, ValueRange operands,
229                                DictionaryAttr attributes, RegionRange regions,
230                                SmallVectorImpl<Type> &inferredReturnTypes) {
231   inferredReturnTypes.push_back(ShapeType::get(context));
232   return success();
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // ConstSizeOp
237 //===----------------------------------------------------------------------===//
238 
239 LogicalResult
240 ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
241                               ValueRange operands, DictionaryAttr attributes,
242                               RegionRange regions,
243                               SmallVectorImpl<Type> &inferredReturnTypes) {
244   inferredReturnTypes.push_back(SizeType::get(context));
245   return success();
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // FromExtentsOp
250 //===----------------------------------------------------------------------===//
251 
252 LogicalResult FromExtentsOp::inferReturnTypes(
253     MLIRContext *context, Optional<Location> location, ValueRange operands,
254     DictionaryAttr attributes, RegionRange regions,
255     SmallVectorImpl<Type> &inferredReturnTypes) {
256   inferredReturnTypes.push_back(ShapeType::get(context));
257   return success();
258 }
259 
260 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
261   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
262     return nullptr;
263   SmallVector<int64_t, 6> extents;
264   for (auto attr : operands)
265     extents.push_back(attr.cast<IntegerAttr>().getInt());
266   Builder builder(getContext());
267   return builder.getI64TensorAttr(extents);
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // ShapeOfOp
272 //===----------------------------------------------------------------------===//
273 
274 LogicalResult
275 ShapeOfOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
276                             ValueRange operands, DictionaryAttr attributes,
277                             RegionRange regions,
278                             SmallVectorImpl<Type> &inferredReturnTypes) {
279   inferredReturnTypes.push_back(ShapeType::get(context));
280   return success();
281 }
282 
283 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
284   auto type = getOperand().getType().dyn_cast<ShapedType>();
285   if (!type || !type.hasStaticShape())
286     return nullptr;
287   Builder builder(getContext());
288   return builder.getI64TensorAttr(type.getShape());
289 }
290 
291 //===----------------------------------------------------------------------===//
292 // SplitAtOp
293 //===----------------------------------------------------------------------===//
294 
295 LogicalResult
296 SplitAtOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
297                             ValueRange operands, DictionaryAttr attributes,
298                             RegionRange regions,
299                             SmallVectorImpl<Type> &inferredReturnTypes) {
300   auto shapeType = ShapeType::get(context);
301   inferredReturnTypes.push_back(shapeType);
302   inferredReturnTypes.push_back(shapeType);
303   return success();
304 }
305 
306 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
307                               SmallVectorImpl<OpFoldResult> &results) {
308   if (!operands[0] || !operands[1])
309     return failure();
310   auto shapeVec = llvm::to_vector<6>(
311       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
312   auto shape = llvm::makeArrayRef(shapeVec);
313   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
314   // Verify that the split point is in the correct range.
315   // TODO: Constant fold to an "error".
316   int64_t rank = shape.size();
317   if (!(-rank <= splitPoint && splitPoint <= rank))
318     return failure();
319   if (splitPoint < 0)
320     splitPoint += shape.size();
321   Builder builder(operands[0].getContext());
322   results.push_back(builder.getI64TensorAttr(shape.take_front(splitPoint)));
323   results.push_back(builder.getI64TensorAttr(shape.drop_front(splitPoint)));
324   return success();
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // ConcatOp
329 //===----------------------------------------------------------------------===//
330 
331 LogicalResult
332 ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
333                            ValueRange operands, DictionaryAttr attributes,
334                            RegionRange regions,
335                            SmallVectorImpl<Type> &inferredReturnTypes) {
336   auto shapeType = ShapeType::get(context);
337   inferredReturnTypes.push_back(shapeType);
338   return success();
339 }
340 
341 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
342   if (!operands[0] || !operands[1])
343     return nullptr;
344   auto lhsShape = llvm::to_vector<6>(
345       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
346   auto rhsShape = llvm::to_vector<6>(
347       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
348   SmallVector<int64_t, 6> resultShape;
349   resultShape.append(lhsShape.begin(), lhsShape.end());
350   resultShape.append(rhsShape.begin(), rhsShape.end());
351   Builder builder(getContext());
352   return builder.getI64TensorAttr(resultShape);
353 }
354 
355 //===----------------------------------------------------------------------===//
356 // ToExtentTensorOp
357 //===----------------------------------------------------------------------===//
358 
359 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
360   if (!operands[0])
361     return nullptr;
362   Builder builder(getContext());
363   auto shape = llvm::to_vector<6>(
364       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
365   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
366                                     builder.getIndexType());
367   return DenseIntElementsAttr::get(type, shape);
368 }
369 
namespace mlir {
namespace shape {

// Operation class definitions, expanded from ShapeOps.cpp.inc under
// GET_OP_CLASSES. They are placed inside mlir::shape to match the namespace
// the dialect's op classes are referenced from above.
#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"

} // namespace shape
} // namespace mlir
378