1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/Traits.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/StandardTypes.h"
16 #include "llvm/Support/raw_ostream.h"
17 
18 using namespace mlir;
19 using namespace mlir::shape;
20 
/// Constructs the shape dialect: registers every op listed in the
/// ODS-generated ShapeOps.cpp.inc file and all of the dialect's custom types.
ShapeDialect::ShapeDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // The op list is expanded from the generated file via GET_OP_LIST.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
           WitnessType>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
34 
35 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
36                                              Attribute value, Type type,
37                                              Location loc) {
38   if (auto shapeType = type.dyn_cast<ShapeType>()) {
39     return builder.create<ConstShapeOp>(loc, type,
40                                         value.cast<DenseIntElementsAttr>());
41   }
42   if (auto sizeType = type.dyn_cast<SizeType>()) {
43     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
44   }
45   return nullptr;
46 }
47 
48 /// Parse a type registered to this dialect.
49 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
50   StringRef keyword;
51   if (parser.parseKeyword(&keyword))
52     return Type();
53 
54   if (keyword == "component")
55     return ComponentType::get(getContext());
56   if (keyword == "element")
57     return ElementType::get(getContext());
58   if (keyword == "shape")
59     return ShapeType::get(getContext());
60   if (keyword == "size")
61     return SizeType::get(getContext());
62   if (keyword == "value_shape")
63     return ValueShapeType::get(getContext());
64   if (keyword == "witness")
65     return WitnessType::get(getContext());
66 
67   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
68   return Type();
69 }
70 
71 /// Print a type registered to this dialect.
72 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
73   switch (type.getKind()) {
74   case ShapeTypes::Component:
75     os << "component";
76     return;
77   case ShapeTypes::Element:
78     os << "element";
79     return;
80   case ShapeTypes::Size:
81     os << "size";
82     return;
83   case ShapeTypes::Shape:
84     os << "shape";
85     return;
86   case ShapeTypes::ValueShape:
87     os << "value_shape";
88     return;
89   case ShapeTypes::Witness:
90     os << "witness";
91     return;
92   default:
93     llvm_unreachable("unexpected 'shape' type kind");
94   }
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // AnyOp
99 //===----------------------------------------------------------------------===//
100 
101 LogicalResult
102 AnyOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
103                         ValueRange operands, DictionaryAttr attributes,
104                         RegionRange regions,
105                         SmallVectorImpl<Type> &inferredReturnTypes) {
106   inferredReturnTypes.push_back(ShapeType::get(context));
107   return success();
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // AssumingOp
112 //===----------------------------------------------------------------------===//
113 
114 static ParseResult parseAssumingOp(OpAsmParser &parser,
115                                    OperationState &result) {
116   result.regions.reserve(1);
117   Region *doRegion = result.addRegion();
118 
119   auto &builder = parser.getBuilder();
120   OpAsmParser::OperandType cond;
121   if (parser.parseOperand(cond) ||
122       parser.resolveOperand(cond, builder.getType<WitnessType>(),
123                             result.operands))
124     return failure();
125 
126   // Parse optional results type list.
127   if (parser.parseOptionalArrowTypeList(result.types))
128     return failure();
129 
130   // Parse the region and add a terminator if elided.
131   if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
132     return failure();
133   AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
134 
135   // Parse the optional attribute list.
136   if (parser.parseOptionalAttrDict(result.attributes))
137     return failure();
138   return success();
139 }
140 
141 static void print(OpAsmPrinter &p, AssumingOp op) {
142   bool yieldsResults = !op.results().empty();
143 
144   p << AssumingOp::getOperationName() << " " << op.witness();
145   if (yieldsResults) {
146     p << " -> (" << op.getResultTypes() << ")";
147   }
148   p.printRegion(op.doRegion(),
149                 /*printEntryBlockArgs=*/false,
150                 /*printBlockTerminators=*/yieldsResults);
151   p.printOptionalAttrDict(op.getAttrs());
152 }
153 
154 //===----------------------------------------------------------------------===//
155 // BroadcastOp
156 //===----------------------------------------------------------------------===//
157 
158 LogicalResult
159 BroadcastOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
160                               ValueRange operands, DictionaryAttr attributes,
161                               RegionRange regions,
162                               SmallVectorImpl<Type> &inferredReturnTypes) {
163   inferredReturnTypes.push_back(ShapeType::get(context));
164   return success();
165 }
166 
167 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
168   if (!operands[0] || !operands[1])
169     return nullptr;
170   auto lhsShape = llvm::to_vector<6>(
171       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
172   auto rhsShape = llvm::to_vector<6>(
173       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
174   SmallVector<int64_t, 6> resultShape;
175   // If the shapes are not compatible, we can't fold it.
176   // TODO: Fold to an "error".
177   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
178     return nullptr;
179   Builder builder(getContext());
180   return builder.getIndexTensorAttr(resultShape);
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // ConcatOp
185 //===----------------------------------------------------------------------===//
186 
187 LogicalResult
188 ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
189                            ValueRange operands, DictionaryAttr attributes,
190                            RegionRange regions,
191                            SmallVectorImpl<Type> &inferredReturnTypes) {
192   auto shapeType = ShapeType::get(context);
193   inferredReturnTypes.push_back(shapeType);
194   return success();
195 }
196 
197 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
198   if (!operands[0] || !operands[1])
199     return nullptr;
200   auto lhsShape = llvm::to_vector<6>(
201       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
202   auto rhsShape = llvm::to_vector<6>(
203       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
204   SmallVector<int64_t, 6> resultShape;
205   resultShape.append(lhsShape.begin(), lhsShape.end());
206   resultShape.append(rhsShape.begin(), rhsShape.end());
207   Builder builder(getContext());
208   return builder.getIndexTensorAttr(resultShape);
209 }
210 
211 //===----------------------------------------------------------------------===//
212 // ConstShapeOp
213 //===----------------------------------------------------------------------===//
214 
215 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
216   p << "shape.const_shape ";
217   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
218   p << "[";
219   interleaveComma(op.shape().getValues<int64_t>(), p,
220                   [&](int64_t i) { p << i; });
221   p << "]";
222 }
223 
224 static ParseResult parseConstShapeOp(OpAsmParser &parser,
225                                      OperationState &result) {
226   if (parser.parseOptionalAttrDict(result.attributes))
227     return failure();
228   // We piggy-back on ArrayAttr parsing, though we don't internally store the
229   // shape as an ArrayAttr.
230   // TODO: Implement custom parser and maybe make syntax a bit more concise.
231   Attribute extentsRaw;
232   NamedAttrList dummy;
233   if (parser.parseAttribute(extentsRaw, "dummy", dummy))
234     return failure();
235   auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
236   if (!extentsArray)
237     return failure();
238   SmallVector<int64_t, 6> ints;
239   for (Attribute extent : extentsArray) {
240     IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
241     if (!attr)
242       return failure();
243     ints.push_back(attr.getInt());
244   }
245   Builder &builder = parser.getBuilder();
246   result.addAttribute("shape", builder.getIndexTensorAttr(ints));
247 
248   result.types.push_back(ShapeType::get(builder.getContext()));
249   return success();
250 }
251 
252 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
253 
254 //===----------------------------------------------------------------------===//
255 // ConstSizeOp
256 //===----------------------------------------------------------------------===//
257 
258 LogicalResult
259 ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
260                               ValueRange operands, DictionaryAttr attributes,
261                               RegionRange regions,
262                               SmallVectorImpl<Type> &inferredReturnTypes) {
263   inferredReturnTypes.push_back(SizeType::get(context));
264   return success();
265 }
266 
267 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
268 
269 //===----------------------------------------------------------------------===//
270 // IndexToSizeOp
271 //===----------------------------------------------------------------------===//
272 
273 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
274   // Constant values of both types, `shape.size` and `index`, are represented as
275   // `IntegerAttr`s which makes constant folding simple.
276   if (Attribute arg = operands[0])
277     return arg;
278   return {};
279 }
280 
281 LogicalResult IndexToSizeOp::inferReturnTypes(
282     MLIRContext *context, Optional<Location> location, ValueRange operands,
283     DictionaryAttr attributes, RegionRange regions,
284     SmallVectorImpl<Type> &inferredReturnTypes) {
285   inferredReturnTypes.push_back(SizeType::get(context));
286   return success();
287 }
288 
289 //===----------------------------------------------------------------------===//
290 // FromExtentsOp
291 //===----------------------------------------------------------------------===//
292 
293 LogicalResult FromExtentsOp::inferReturnTypes(
294     MLIRContext *context, Optional<Location> location, ValueRange operands,
295     DictionaryAttr attributes, RegionRange regions,
296     SmallVectorImpl<Type> &inferredReturnTypes) {
297   inferredReturnTypes.push_back(ShapeType::get(context));
298   return success();
299 }
300 
301 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
302   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
303     return nullptr;
304   SmallVector<int64_t, 6> extents;
305   for (auto attr : operands)
306     extents.push_back(attr.cast<IntegerAttr>().getInt());
307   Builder builder(getContext());
308   return builder.getIndexTensorAttr(extents);
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // GetExtentOp
313 //===----------------------------------------------------------------------===//
314 
315 LogicalResult
316 GetExtentOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
317                               ValueRange operands, DictionaryAttr attributes,
318                               RegionRange regions,
319                               SmallVectorImpl<Type> &inferredReturnTypes) {
320   inferredReturnTypes.push_back(SizeType::get(context));
321   return success();
322 }
323 
324 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
325   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
326   if (!elements)
327     return nullptr;
328   uint64_t dimToGet = dim().getLimitedValue();
329   // TODO: Constant fold this to some kind of constant error.
330   if (dimToGet >= (uint64_t)elements.getNumElements())
331     return nullptr;
332   return elements.getValue({dimToGet});
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // NumElementsOp
337 //===----------------------------------------------------------------------===//
338 
339 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
340 
341   // Fold only when argument constant.
342   Attribute shape = operands[0];
343   if (!shape)
344     return {};
345 
346   APInt product(64, 1);
347   for (auto value : shape.cast<DenseIntElementsAttr>())
348     product *= value;
349   Builder builder(getContext());
350   return builder.getIndexAttr(product.getLimitedValue());
351 }
352 
353 LogicalResult NumElementsOp::inferReturnTypes(
354     MLIRContext *context, Optional<Location> location, ValueRange operands,
355     DictionaryAttr attributes, RegionRange regions,
356     SmallVectorImpl<Type> &inferredReturnTypes) {
357   inferredReturnTypes.push_back(SizeType::get(context));
358   return success();
359 }
360 
361 //===----------------------------------------------------------------------===//
362 // ShapeOfOp
363 //===----------------------------------------------------------------------===//
364 
365 LogicalResult
366 ShapeOfOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
367                             ValueRange operands, DictionaryAttr attributes,
368                             RegionRange regions,
369                             SmallVectorImpl<Type> &inferredReturnTypes) {
370   inferredReturnTypes.push_back(ShapeType::get(context));
371   return success();
372 }
373 
374 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
375   auto type = getOperand().getType().dyn_cast<ShapedType>();
376   if (!type || !type.hasStaticShape())
377     return nullptr;
378   Builder builder(getContext());
379   return builder.getIndexTensorAttr(type.getShape());
380 }
381 
382 //===----------------------------------------------------------------------===//
383 // SizeToIndexOp
384 //===----------------------------------------------------------------------===//
385 
386 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
387   // Constant values of both types, `shape.size` and `index`, are represented as
388   // `IntegerAttr`s which makes constant folding simple.
389   if (Attribute arg = operands[0])
390     return arg;
391   return {};
392 }
393 
394 LogicalResult SizeToIndexOp::inferReturnTypes(
395     MLIRContext *context, Optional<Location> location, ValueRange operands,
396     DictionaryAttr attributes, RegionRange regions,
397     SmallVectorImpl<Type> &inferredReturnTypes) {
398   inferredReturnTypes.push_back(IndexType::get(context));
399   return success();
400 }
401 
402 //===----------------------------------------------------------------------===//
403 // SplitAtOp
404 //===----------------------------------------------------------------------===//
405 
406 LogicalResult
407 SplitAtOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
408                             ValueRange operands, DictionaryAttr attributes,
409                             RegionRange regions,
410                             SmallVectorImpl<Type> &inferredReturnTypes) {
411   auto shapeType = ShapeType::get(context);
412   inferredReturnTypes.push_back(shapeType);
413   inferredReturnTypes.push_back(shapeType);
414   return success();
415 }
416 
417 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
418                               SmallVectorImpl<OpFoldResult> &results) {
419   if (!operands[0] || !operands[1])
420     return failure();
421   auto shapeVec = llvm::to_vector<6>(
422       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
423   auto shape = llvm::makeArrayRef(shapeVec);
424   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
425   // Verify that the split point is in the correct range.
426   // TODO: Constant fold to an "error".
427   int64_t rank = shape.size();
428   if (!(-rank <= splitPoint && splitPoint <= rank))
429     return failure();
430   if (splitPoint < 0)
431     splitPoint += shape.size();
432   Builder builder(operands[0].getContext());
433   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
434   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
435   return success();
436 }
437 
438 //===----------------------------------------------------------------------===//
439 // ToExtentTensorOp
440 //===----------------------------------------------------------------------===//
441 
442 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
443   if (!operands[0])
444     return nullptr;
445   Builder builder(getContext());
446   auto shape = llvm::to_vector<6>(
447       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
448   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
449                                     builder.getIndexType());
450   return DenseIntElementsAttr::get(type, shape);
451 }
452 
// Pull in the ODS-generated op class definitions inside `mlir::shape`. The
// static parse/print helpers above are not called anywhere else in this file,
// so they are presumably referenced from this generated code. Keep this include
// after them.
namespace mlir {
namespace shape {

#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"

} // namespace shape
} // namespace mlir
461