1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/Traits.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/StandardTypes.h" 16 #include "llvm/Support/raw_ostream.h" 17 18 using namespace mlir; 19 using namespace mlir::shape; 20 21 ShapeDialect::ShapeDialect(MLIRContext *context) 22 : Dialect(getDialectNamespace(), context) { 23 addOperations< 24 #define GET_OP_LIST 25 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 26 >(); 27 addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType, 28 WitnessType>(); 29 // Allow unknown operations during prototyping and testing. As the dialect is 30 // still evolving it makes it simple to start with an unregistered ops and 31 // try different variants before actually defining the op. 32 allowUnknownOperations(); 33 } 34 35 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 36 Attribute value, Type type, 37 Location loc) { 38 if (auto shapeType = type.dyn_cast<ShapeType>()) { 39 return builder.create<ConstShapeOp>(loc, type, 40 value.cast<DenseIntElementsAttr>()); 41 } 42 if (auto sizeType = type.dyn_cast<SizeType>()) { 43 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 44 } 45 return nullptr; 46 } 47 48 /// Parse a type registered to this dialect. 49 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 50 StringRef keyword; 51 if (parser.parseKeyword(&keyword)) 52 return Type(); 53 54 if (keyword == "component") 55 return ComponentType::get(getContext()); 56 if (keyword == "element") 57 return ElementType::get(getContext()); 58 if (keyword == "shape") 59 return ShapeType::get(getContext()); 60 if (keyword == "size") 61 return SizeType::get(getContext()); 62 if (keyword == "value_shape") 63 return ValueShapeType::get(getContext()); 64 if (keyword == "witness") 65 return WitnessType::get(getContext()); 66 67 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 68 return Type(); 69 } 70 71 /// Print a type registered to this dialect. 72 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 73 switch (type.getKind()) { 74 case ShapeTypes::Component: 75 os << "component"; 76 return; 77 case ShapeTypes::Element: 78 os << "element"; 79 return; 80 case ShapeTypes::Size: 81 os << "size"; 82 return; 83 case ShapeTypes::Shape: 84 os << "shape"; 85 return; 86 case ShapeTypes::ValueShape: 87 os << "value_shape"; 88 return; 89 case ShapeTypes::Witness: 90 os << "witness"; 91 return; 92 default: 93 llvm_unreachable("unexpected 'shape' type kind"); 94 } 95 } 96 97 //===----------------------------------------------------------------------===// 98 // AnyOp 99 //===----------------------------------------------------------------------===// 100 101 LogicalResult 102 AnyOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 103 ValueRange operands, DictionaryAttr attributes, 104 RegionRange regions, 105 SmallVectorImpl<Type> &inferredReturnTypes) { 106 inferredReturnTypes.push_back(ShapeType::get(context)); 107 return success(); 108 } 109 110 //===----------------------------------------------------------------------===// 111 // BroadcastOp 112 //===----------------------------------------------------------------------===// 113 114 LogicalResult 115 BroadcastOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 116 ValueRange operands, DictionaryAttr attributes, 117 RegionRange regions, 118 SmallVectorImpl<Type> &inferredReturnTypes) { 119 inferredReturnTypes.push_back(ShapeType::get(context)); 120 return success(); 121 } 122 123 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 124 if (!operands[0] || !operands[1]) 125 return nullptr; 126 auto lhsShape = llvm::to_vector<6>( 127 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 128 auto rhsShape = llvm::to_vector<6>( 129 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 130 SmallVector<int64_t, 6> resultShape; 131 // If the shapes are not compatible, we can't fold it. 132 // TODO: Fold to an "error". 133 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 134 return nullptr; 135 Builder builder(getContext()); 136 return builder.getI64TensorAttr(resultShape); 137 } 138 139 //===----------------------------------------------------------------------===// 140 // ConstShapeOp 141 //===----------------------------------------------------------------------===// 142 143 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 144 p << "shape.const_shape "; 145 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); 146 p << "["; 147 interleaveComma(op.shape().getValues<int64_t>(), p, 148 [&](int64_t i) { p << i; }); 149 p << "]"; 150 } 151 152 static ParseResult parseConstShapeOp(OpAsmParser &parser, 153 OperationState &result) { 154 if (parser.parseOptionalAttrDict(result.attributes)) 155 return failure(); 156 // We piggy-back on ArrayAttr parsing, though we don't internally store the 157 // shape as an ArrayAttr. 158 // TODO: Implement custom parser and maybe make syntax a bit more concise. 159 Attribute extentsRaw; 160 NamedAttrList dummy; 161 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 162 return failure(); 163 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 164 if (!extentsArray) 165 return failure(); 166 SmallVector<int64_t, 6> ints; 167 for (Attribute extent : extentsArray) { 168 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 169 if (!attr) 170 return failure(); 171 ints.push_back(attr.getInt()); 172 } 173 Builder &builder = parser.getBuilder(); 174 result.addAttribute("shape", builder.getI64TensorAttr(ints)); 175 176 result.types.push_back(ShapeType::get(builder.getContext())); 177 return success(); 178 } 179 180 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); } 181 182 LogicalResult 183 ConstShapeOp::inferReturnTypes(MLIRContext *context, 184 Optional<Location> location, ValueRange operands, 185 DictionaryAttr attributes, RegionRange regions, 186 SmallVectorImpl<Type> &inferredReturnTypes) { 187 inferredReturnTypes.push_back(ShapeType::get(context)); 188 return success(); 189 } 190 191 //===----------------------------------------------------------------------===// 192 // ConstSizeOp 193 //===----------------------------------------------------------------------===// 194 195 LogicalResult 196 ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 197 ValueRange operands, DictionaryAttr attributes, 198 RegionRange regions, 199 SmallVectorImpl<Type> &inferredReturnTypes) { 200 inferredReturnTypes.push_back(SizeType::get(context)); 201 return success(); 202 } 203 204 //===----------------------------------------------------------------------===// 205 // ShapeOfOp 206 //===----------------------------------------------------------------------===// 207 208 LogicalResult 209 ShapeOfOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 210 ValueRange operands, DictionaryAttr attributes, 211 RegionRange regions, 212 SmallVectorImpl<Type> &inferredReturnTypes) { 213 inferredReturnTypes.push_back(ShapeType::get(context)); 214 return success(); 215 } 216 217 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 218 auto type = getOperand().getType().dyn_cast<ShapedType>(); 219 if (!type || !type.hasStaticShape()) 220 return nullptr; 221 Builder builder(getContext()); 222 return builder.getI64TensorAttr(type.getShape()); 223 } 224 225 //===----------------------------------------------------------------------===// 226 // SplitAtOp 227 //===----------------------------------------------------------------------===// 228 229 LogicalResult 230 SplitAtOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 231 ValueRange operands, DictionaryAttr attributes, 232 RegionRange regions, 233 SmallVectorImpl<Type> &inferredReturnTypes) { 234 auto shapeType = ShapeType::get(context); 235 inferredReturnTypes.push_back(shapeType); 236 inferredReturnTypes.push_back(shapeType); 237 return success(); 238 } 239 240 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 241 SmallVectorImpl<OpFoldResult> &results) { 242 if (!operands[0] || !operands[1]) 243 return failure(); 244 auto shapeVec = llvm::to_vector<6>( 245 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 246 auto shape = llvm::makeArrayRef(shapeVec); 247 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 248 // Verify that the split point is in the correct range. 249 // TODO: Constant fold to an "error". 250 int64_t rank = shape.size(); 251 if (!(-rank <= splitPoint && splitPoint <= rank)) 252 return failure(); 253 if (splitPoint < 0) 254 splitPoint += shape.size(); 255 Builder builder(operands[0].getContext()); 256 results.push_back(builder.getI64TensorAttr(shape.take_front(splitPoint))); 257 results.push_back(builder.getI64TensorAttr(shape.drop_front(splitPoint))); 258 return success(); 259 } 260 261 //===----------------------------------------------------------------------===// 262 // ConcatOp 263 //===----------------------------------------------------------------------===// 264 265 LogicalResult 266 ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 267 ValueRange operands, DictionaryAttr attributes, 268 RegionRange regions, 269 SmallVectorImpl<Type> &inferredReturnTypes) { 270 auto shapeType = ShapeType::get(context); 271 inferredReturnTypes.push_back(shapeType); 272 return success(); 273 } 274 275 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 276 if (!operands[0] || !operands[1]) 277 return nullptr; 278 auto lhsShape = llvm::to_vector<6>( 279 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 280 auto rhsShape = llvm::to_vector<6>( 281 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 282 SmallVector<int64_t, 6> resultShape; 283 resultShape.append(lhsShape.begin(), lhsShape.end()); 284 resultShape.append(rhsShape.begin(), rhsShape.end()); 285 Builder builder(getContext()); 286 return builder.getI64TensorAttr(resultShape); 287 } 288 289 //===----------------------------------------------------------------------===// 290 // ToExtentTensorOp 291 //===----------------------------------------------------------------------===// 292 293 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 294 if (!operands[0]) 295 return nullptr; 296 Builder builder(getContext()); 297 auto shape = llvm::to_vector<6>( 298 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 299 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 300 builder.getIndexType()); 301 return DenseIntElementsAttr::get(type, shape); 302 } 303 304 namespace mlir { 305 namespace shape { 306 307 #define GET_OP_CLASSES 308 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 309 310 } // namespace shape 311 } // namespace mlir 312