1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/StandardTypes.h" 14 #include "llvm/Support/raw_ostream.h" 15 16 using namespace mlir; 17 using namespace mlir::shape; 18 19 ShapeDialect::ShapeDialect(MLIRContext *context) 20 : Dialect(getDialectNamespace(), context) { 21 addOperations< 22 #define GET_OP_LIST 23 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 24 >(); 25 addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType>(); 26 // Allow unknown operations during prototyping and testing. As the dialect is 27 // still evolving it makes it simple to start with an unregistered ops and 28 // try different variants before actually defining the op. 29 allowUnknownOperations(); 30 } 31 32 /// Parse a type registered to this dialect. 33 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 34 StringRef keyword; 35 if (parser.parseKeyword(&keyword)) 36 return Type(); 37 38 if (keyword == "component") 39 return ComponentType::get(getContext()); 40 if (keyword == "element") 41 return ElementType::get(getContext()); 42 if (keyword == "shape") 43 return ShapeType::get(getContext()); 44 if (keyword == "size") 45 return SizeType::get(getContext()); 46 if (keyword == "value_shape") 47 return ValueShapeType::get(getContext()); 48 49 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 50 return Type(); 51 } 52 53 /// Print a type registered to this dialect. 54 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 55 switch (type.getKind()) { 56 case ShapeTypes::Component: 57 os << "component"; 58 return; 59 case ShapeTypes::Element: 60 os << "element"; 61 return; 62 case ShapeTypes::Size: 63 os << "size"; 64 return; 65 case ShapeTypes::Shape: 66 os << "shape"; 67 return; 68 case ShapeTypes::ValueShape: 69 os << "value_shape"; 70 return; 71 default: 72 llvm_unreachable("unexpected 'shape' type kind"); 73 } 74 } 75 76 //===----------------------------------------------------------------------===// 77 // Constant*Op 78 //===----------------------------------------------------------------------===// 79 80 static void print(OpAsmPrinter &p, ConstantOp &op) { 81 p << "shape.constant "; 82 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); 83 84 if (op.getAttrs().size() > 1) 85 p << ' '; 86 p.printAttributeWithoutType(op.value()); 87 p << " : " << op.getType(); 88 } 89 90 static ParseResult parseConstantOp(OpAsmParser &parser, 91 OperationState &result) { 92 Attribute valueAttr; 93 if (parser.parseOptionalAttrDict(result.attributes)) 94 return failure(); 95 Type i64Type = parser.getBuilder().getIntegerType(64); 96 if (parser.parseAttribute(valueAttr, i64Type, "value", result.attributes)) 97 return failure(); 98 99 Type type; 100 if (parser.parseColonType(type)) 101 return failure(); 102 103 // Add the attribute type to the list. 104 return parser.addTypeToList(type, result.types); 105 } 106 107 static LogicalResult verify(ConstantOp &op) { return success(); } 108 109 //===----------------------------------------------------------------------===// 110 // SplitAtOp 111 //===----------------------------------------------------------------------===// 112 113 LogicalResult SplitAtOp::inferReturnTypes( 114 MLIRContext *context, Optional<Location> location, ValueRange operands, 115 ArrayRef<NamedAttribute> attributes, RegionRange regions, 116 SmallVectorImpl<Type> &inferredReturnTypes) { 117 auto shapeType = ShapeType::get(context); 118 inferredReturnTypes.push_back(shapeType); 119 inferredReturnTypes.push_back(shapeType); 120 return success(); 121 } 122 123 //===----------------------------------------------------------------------===// 124 // ConcatOp 125 //===----------------------------------------------------------------------===// 126 127 LogicalResult ConcatOp::inferReturnTypes( 128 MLIRContext *context, Optional<Location> location, ValueRange operands, 129 ArrayRef<NamedAttribute> attributes, RegionRange regions, 130 SmallVectorImpl<Type> &inferredReturnTypes) { 131 auto shapeType = ShapeType::get(context); 132 inferredReturnTypes.push_back(shapeType); 133 return success(); 134 } 135 136 namespace mlir { 137 namespace shape { 138 139 #define GET_OP_CLASSES 140 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 141 142 } // namespace shape 143 } // namespace mlir 144