1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/Traits.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/StandardTypes.h" 16 #include "llvm/Support/raw_ostream.h" 17 18 using namespace mlir; 19 using namespace mlir::shape; 20 21 ShapeDialect::ShapeDialect(MLIRContext *context) 22 : Dialect(getDialectNamespace(), context) { 23 addOperations< 24 #define GET_OP_LIST 25 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 26 >(); 27 addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType, 28 WitnessType>(); 29 // Allow unknown operations during prototyping and testing. As the dialect is 30 // still evolving it makes it simple to start with an unregistered ops and 31 // try different variants before actually defining the op. 32 allowUnknownOperations(); 33 } 34 35 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 36 Attribute value, Type type, 37 Location loc) { 38 if (auto shapeType = type.dyn_cast<ShapeType>()) { 39 return builder.create<ConstShapeOp>(loc, type, 40 value.cast<DenseIntElementsAttr>()); 41 } 42 if (auto sizeType = type.dyn_cast<SizeType>()) { 43 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 44 } 45 return nullptr; 46 } 47 48 /// Parse a type registered to this dialect. 49 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 50 StringRef keyword; 51 if (parser.parseKeyword(&keyword)) 52 return Type(); 53 54 if (keyword == "component") 55 return ComponentType::get(getContext()); 56 if (keyword == "element") 57 return ElementType::get(getContext()); 58 if (keyword == "shape") 59 return ShapeType::get(getContext()); 60 if (keyword == "size") 61 return SizeType::get(getContext()); 62 if (keyword == "value_shape") 63 return ValueShapeType::get(getContext()); 64 if (keyword == "witness") 65 return WitnessType::get(getContext()); 66 67 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 68 return Type(); 69 } 70 71 /// Print a type registered to this dialect. 72 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 73 switch (type.getKind()) { 74 case ShapeTypes::Component: 75 os << "component"; 76 return; 77 case ShapeTypes::Element: 78 os << "element"; 79 return; 80 case ShapeTypes::Size: 81 os << "size"; 82 return; 83 case ShapeTypes::Shape: 84 os << "shape"; 85 return; 86 case ShapeTypes::ValueShape: 87 os << "value_shape"; 88 return; 89 case ShapeTypes::Witness: 90 os << "witness"; 91 return; 92 default: 93 llvm_unreachable("unexpected 'shape' type kind"); 94 } 95 } 96 97 //===----------------------------------------------------------------------===// 98 // AnyOp 99 //===----------------------------------------------------------------------===// 100 101 LogicalResult 102 AnyOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 103 ValueRange operands, DictionaryAttr attributes, 104 RegionRange regions, 105 SmallVectorImpl<Type> &inferredReturnTypes) { 106 inferredReturnTypes.push_back(ShapeType::get(context)); 107 return success(); 108 } 109 110 //===----------------------------------------------------------------------===// 111 // AssumingOp 112 //===----------------------------------------------------------------------===// 113 114 static ParseResult parseAssumingOp(OpAsmParser &parser, 115 OperationState &result) { 116 result.regions.reserve(1); 117 Region *doRegion = result.addRegion(); 118 119 auto &builder = parser.getBuilder(); 120 OpAsmParser::OperandType cond; 121 if (parser.parseOperand(cond) || 122 parser.resolveOperand(cond, builder.getType<WitnessType>(), 123 result.operands)) 124 return failure(); 125 126 // Parse optional results type list. 127 if (parser.parseOptionalArrowTypeList(result.types)) 128 return failure(); 129 130 // Parse the region and add a terminator if elided. 131 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 132 return failure(); 133 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 134 135 // Parse the optional attribute list. 136 if (parser.parseOptionalAttrDict(result.attributes)) 137 return failure(); 138 return success(); 139 } 140 141 static void print(OpAsmPrinter &p, AssumingOp op) { 142 bool yieldsResults = !op.results().empty(); 143 144 p << AssumingOp::getOperationName() << " " << op.witness(); 145 if (yieldsResults) { 146 p << " -> (" << op.getResultTypes() << ")"; 147 } 148 p.printRegion(op.doRegion(), 149 /*printEntryBlockArgs=*/false, 150 /*printBlockTerminators=*/yieldsResults); 151 p.printOptionalAttrDict(op.getAttrs()); 152 } 153 154 //===----------------------------------------------------------------------===// 155 // BroadcastOp 156 //===----------------------------------------------------------------------===// 157 158 LogicalResult 159 BroadcastOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 160 ValueRange operands, DictionaryAttr attributes, 161 RegionRange regions, 162 SmallVectorImpl<Type> &inferredReturnTypes) { 163 inferredReturnTypes.push_back(ShapeType::get(context)); 164 return success(); 165 } 166 167 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 168 if (!operands[0] || !operands[1]) 169 return nullptr; 170 auto lhsShape = llvm::to_vector<6>( 171 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 172 auto rhsShape = llvm::to_vector<6>( 173 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 174 SmallVector<int64_t, 6> resultShape; 175 // If the shapes are not compatible, we can't fold it. 176 // TODO: Fold to an "error". 177 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 178 return nullptr; 179 Builder builder(getContext()); 180 return builder.getI64TensorAttr(resultShape); 181 } 182 183 //===----------------------------------------------------------------------===// 184 // ConstShapeOp 185 //===----------------------------------------------------------------------===// 186 187 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 188 p << "shape.const_shape "; 189 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); 190 p << "["; 191 interleaveComma(op.shape().getValues<int64_t>(), p, 192 [&](int64_t i) { p << i; }); 193 p << "]"; 194 } 195 196 static ParseResult parseConstShapeOp(OpAsmParser &parser, 197 OperationState &result) { 198 if (parser.parseOptionalAttrDict(result.attributes)) 199 return failure(); 200 // We piggy-back on ArrayAttr parsing, though we don't internally store the 201 // shape as an ArrayAttr. 202 // TODO: Implement custom parser and maybe make syntax a bit more concise. 203 Attribute extentsRaw; 204 NamedAttrList dummy; 205 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 206 return failure(); 207 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 208 if (!extentsArray) 209 return failure(); 210 SmallVector<int64_t, 6> ints; 211 for (Attribute extent : extentsArray) { 212 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 213 if (!attr) 214 return failure(); 215 ints.push_back(attr.getInt()); 216 } 217 Builder &builder = parser.getBuilder(); 218 result.addAttribute("shape", builder.getI64TensorAttr(ints)); 219 220 result.types.push_back(ShapeType::get(builder.getContext())); 221 return success(); 222 } 223 224 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); } 225 226 LogicalResult 227 ConstShapeOp::inferReturnTypes(MLIRContext *context, 228 Optional<Location> location, ValueRange operands, 229 DictionaryAttr attributes, RegionRange regions, 230 SmallVectorImpl<Type> &inferredReturnTypes) { 231 inferredReturnTypes.push_back(ShapeType::get(context)); 232 return success(); 233 } 234 235 //===----------------------------------------------------------------------===// 236 // ConstSizeOp 237 //===----------------------------------------------------------------------===// 238 239 LogicalResult 240 ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 241 ValueRange operands, DictionaryAttr attributes, 242 RegionRange regions, 243 SmallVectorImpl<Type> &inferredReturnTypes) { 244 inferredReturnTypes.push_back(SizeType::get(context)); 245 return success(); 246 } 247 248 //===----------------------------------------------------------------------===// 249 // FromExtentsOp 250 //===----------------------------------------------------------------------===// 251 252 LogicalResult FromExtentsOp::inferReturnTypes( 253 MLIRContext *context, Optional<Location> location, ValueRange operands, 254 DictionaryAttr attributes, RegionRange regions, 255 SmallVectorImpl<Type> &inferredReturnTypes) { 256 inferredReturnTypes.push_back(ShapeType::get(context)); 257 return success(); 258 } 259 260 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 261 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 262 return nullptr; 263 SmallVector<int64_t, 6> extents; 264 for (auto attr : operands) 265 extents.push_back(attr.cast<IntegerAttr>().getInt()); 266 Builder builder(getContext()); 267 return builder.getI64TensorAttr(extents); 268 } 269 270 //===----------------------------------------------------------------------===// 271 // ShapeOfOp 272 //===----------------------------------------------------------------------===// 273 274 LogicalResult 275 ShapeOfOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 276 ValueRange operands, DictionaryAttr attributes, 277 RegionRange regions, 278 SmallVectorImpl<Type> &inferredReturnTypes) { 279 inferredReturnTypes.push_back(ShapeType::get(context)); 280 return success(); 281 } 282 283 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 284 auto type = getOperand().getType().dyn_cast<ShapedType>(); 285 if (!type || !type.hasStaticShape()) 286 return nullptr; 287 Builder builder(getContext()); 288 return builder.getI64TensorAttr(type.getShape()); 289 } 290 291 //===----------------------------------------------------------------------===// 292 // SplitAtOp 293 //===----------------------------------------------------------------------===// 294 295 LogicalResult 296 SplitAtOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 297 ValueRange operands, DictionaryAttr attributes, 298 RegionRange regions, 299 SmallVectorImpl<Type> &inferredReturnTypes) { 300 auto shapeType = ShapeType::get(context); 301 inferredReturnTypes.push_back(shapeType); 302 inferredReturnTypes.push_back(shapeType); 303 return success(); 304 } 305 306 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 307 SmallVectorImpl<OpFoldResult> &results) { 308 if (!operands[0] || !operands[1]) 309 return failure(); 310 auto shapeVec = llvm::to_vector<6>( 311 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 312 auto shape = llvm::makeArrayRef(shapeVec); 313 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 314 // Verify that the split point is in the correct range. 315 // TODO: Constant fold to an "error". 316 int64_t rank = shape.size(); 317 if (!(-rank <= splitPoint && splitPoint <= rank)) 318 return failure(); 319 if (splitPoint < 0) 320 splitPoint += shape.size(); 321 Builder builder(operands[0].getContext()); 322 results.push_back(builder.getI64TensorAttr(shape.take_front(splitPoint))); 323 results.push_back(builder.getI64TensorAttr(shape.drop_front(splitPoint))); 324 return success(); 325 } 326 327 //===----------------------------------------------------------------------===// 328 // ConcatOp 329 //===----------------------------------------------------------------------===// 330 331 LogicalResult 332 ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 333 ValueRange operands, DictionaryAttr attributes, 334 RegionRange regions, 335 SmallVectorImpl<Type> &inferredReturnTypes) { 336 auto shapeType = ShapeType::get(context); 337 inferredReturnTypes.push_back(shapeType); 338 return success(); 339 } 340 341 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 342 if (!operands[0] || !operands[1]) 343 return nullptr; 344 auto lhsShape = llvm::to_vector<6>( 345 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 346 auto rhsShape = llvm::to_vector<6>( 347 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 348 SmallVector<int64_t, 6> resultShape; 349 resultShape.append(lhsShape.begin(), lhsShape.end()); 350 resultShape.append(rhsShape.begin(), rhsShape.end()); 351 Builder builder(getContext()); 352 return builder.getI64TensorAttr(resultShape); 353 } 354 355 //===----------------------------------------------------------------------===// 356 // ToExtentTensorOp 357 //===----------------------------------------------------------------------===// 358 359 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 360 if (!operands[0]) 361 return nullptr; 362 Builder builder(getContext()); 363 auto shape = llvm::to_vector<6>( 364 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 365 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 366 builder.getIndexType()); 367 return DenseIntElementsAttr::get(type, shape); 368 } 369 370 namespace mlir { 371 namespace shape { 372 373 #define GET_OP_CLASSES 374 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 375 376 } // namespace shape 377 } // namespace mlir 378