1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/Traits.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/StandardTypes.h" 16 #include "llvm/Support/raw_ostream.h" 17 18 using namespace mlir; 19 using namespace mlir::shape; 20 21 ShapeDialect::ShapeDialect(MLIRContext *context) 22 : Dialect(getDialectNamespace(), context) { 23 addOperations< 24 #define GET_OP_LIST 25 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 26 >(); 27 addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType, 28 WitnessType>(); 29 // Allow unknown operations during prototyping and testing. As the dialect is 30 // still evolving it makes it simple to start with an unregistered ops and 31 // try different variants before actually defining the op. 32 allowUnknownOperations(); 33 } 34 35 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 36 Attribute value, Type type, 37 Location loc) { 38 if (auto shapeType = type.dyn_cast<ShapeType>()) { 39 return builder.create<ConstShapeOp>(loc, type, 40 value.cast<DenseIntElementsAttr>()); 41 } 42 if (auto sizeType = type.dyn_cast<SizeType>()) { 43 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 44 } 45 return nullptr; 46 } 47 48 /// Parse a type registered to this dialect. 49 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 50 StringRef keyword; 51 if (parser.parseKeyword(&keyword)) 52 return Type(); 53 54 if (keyword == "component") 55 return ComponentType::get(getContext()); 56 if (keyword == "element") 57 return ElementType::get(getContext()); 58 if (keyword == "shape") 59 return ShapeType::get(getContext()); 60 if (keyword == "size") 61 return SizeType::get(getContext()); 62 if (keyword == "value_shape") 63 return ValueShapeType::get(getContext()); 64 if (keyword == "witness") 65 return WitnessType::get(getContext()); 66 67 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 68 return Type(); 69 } 70 71 /// Print a type registered to this dialect. 72 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 73 switch (type.getKind()) { 74 case ShapeTypes::Component: 75 os << "component"; 76 return; 77 case ShapeTypes::Element: 78 os << "element"; 79 return; 80 case ShapeTypes::Size: 81 os << "size"; 82 return; 83 case ShapeTypes::Shape: 84 os << "shape"; 85 return; 86 case ShapeTypes::ValueShape: 87 os << "value_shape"; 88 return; 89 case ShapeTypes::Witness: 90 os << "witness"; 91 return; 92 default: 93 llvm_unreachable("unexpected 'shape' type kind"); 94 } 95 } 96 97 //===----------------------------------------------------------------------===// 98 // AnyOp 99 //===----------------------------------------------------------------------===// 100 101 LogicalResult 102 AnyOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 103 ValueRange operands, DictionaryAttr attributes, 104 RegionRange regions, 105 SmallVectorImpl<Type> &inferredReturnTypes) { 106 inferredReturnTypes.push_back(ShapeType::get(context)); 107 return success(); 108 } 109 110 //===----------------------------------------------------------------------===// 111 // AssumingOp 112 //===----------------------------------------------------------------------===// 113 114 static ParseResult parseAssumingOp(OpAsmParser &parser, 115 OperationState &result) { 116 result.regions.reserve(1); 117 Region *doRegion = result.addRegion(); 118 119 auto &builder = parser.getBuilder(); 120 OpAsmParser::OperandType cond; 121 if (parser.parseOperand(cond) || 122 parser.resolveOperand(cond, builder.getType<WitnessType>(), 123 result.operands)) 124 return failure(); 125 126 // Parse optional results type list. 127 if (parser.parseOptionalArrowTypeList(result.types)) 128 return failure(); 129 130 // Parse the region and add a terminator if elided. 131 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 132 return failure(); 133 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 134 135 // Parse the optional attribute list. 136 if (parser.parseOptionalAttrDict(result.attributes)) 137 return failure(); 138 return success(); 139 } 140 141 static void print(OpAsmPrinter &p, AssumingOp op) { 142 bool yieldsResults = !op.results().empty(); 143 144 p << AssumingOp::getOperationName() << " " << op.witness(); 145 if (yieldsResults) { 146 p << " -> (" << op.getResultTypes() << ")"; 147 } 148 p.printRegion(op.doRegion(), 149 /*printEntryBlockArgs=*/false, 150 /*printBlockTerminators=*/yieldsResults); 151 p.printOptionalAttrDict(op.getAttrs()); 152 } 153 154 //===----------------------------------------------------------------------===// 155 // BroadcastOp 156 //===----------------------------------------------------------------------===// 157 158 LogicalResult 159 BroadcastOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 160 ValueRange operands, DictionaryAttr attributes, 161 RegionRange regions, 162 SmallVectorImpl<Type> &inferredReturnTypes) { 163 inferredReturnTypes.push_back(ShapeType::get(context)); 164 return success(); 165 } 166 167 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 168 if (!operands[0] || !operands[1]) 169 return nullptr; 170 auto lhsShape = llvm::to_vector<6>( 171 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 172 auto rhsShape = llvm::to_vector<6>( 173 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 174 SmallVector<int64_t, 6> resultShape; 175 // If the shapes are not compatible, we can't fold it. 176 // TODO: Fold to an "error". 177 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 178 return nullptr; 179 Builder builder(getContext()); 180 return builder.getIndexTensorAttr(resultShape); 181 } 182 183 //===----------------------------------------------------------------------===// 184 // ConcatOp 185 //===----------------------------------------------------------------------===// 186 187 LogicalResult 188 ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 189 ValueRange operands, DictionaryAttr attributes, 190 RegionRange regions, 191 SmallVectorImpl<Type> &inferredReturnTypes) { 192 auto shapeType = ShapeType::get(context); 193 inferredReturnTypes.push_back(shapeType); 194 return success(); 195 } 196 197 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 198 if (!operands[0] || !operands[1]) 199 return nullptr; 200 auto lhsShape = llvm::to_vector<6>( 201 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 202 auto rhsShape = llvm::to_vector<6>( 203 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 204 SmallVector<int64_t, 6> resultShape; 205 resultShape.append(lhsShape.begin(), lhsShape.end()); 206 resultShape.append(rhsShape.begin(), rhsShape.end()); 207 Builder builder(getContext()); 208 return builder.getIndexTensorAttr(resultShape); 209 } 210 211 //===----------------------------------------------------------------------===// 212 // ConstShapeOp 213 //===----------------------------------------------------------------------===// 214 215 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 216 p << "shape.const_shape "; 217 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); 218 p << "["; 219 interleaveComma(op.shape().getValues<int64_t>(), p, 220 [&](int64_t i) { p << i; }); 221 p << "]"; 222 } 223 224 static ParseResult parseConstShapeOp(OpAsmParser &parser, 225 OperationState &result) { 226 if (parser.parseOptionalAttrDict(result.attributes)) 227 return failure(); 228 // We piggy-back on ArrayAttr parsing, though we don't internally store the 229 // shape as an ArrayAttr. 230 // TODO: Implement custom parser and maybe make syntax a bit more concise. 231 Attribute extentsRaw; 232 NamedAttrList dummy; 233 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 234 return failure(); 235 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 236 if (!extentsArray) 237 return failure(); 238 SmallVector<int64_t, 6> ints; 239 for (Attribute extent : extentsArray) { 240 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 241 if (!attr) 242 return failure(); 243 ints.push_back(attr.getInt()); 244 } 245 Builder &builder = parser.getBuilder(); 246 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 247 248 result.types.push_back(ShapeType::get(builder.getContext())); 249 return success(); 250 } 251 252 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 253 254 //===----------------------------------------------------------------------===// 255 // ConstSizeOp 256 //===----------------------------------------------------------------------===// 257 258 LogicalResult 259 ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 260 ValueRange operands, DictionaryAttr attributes, 261 RegionRange regions, 262 SmallVectorImpl<Type> &inferredReturnTypes) { 263 inferredReturnTypes.push_back(SizeType::get(context)); 264 return success(); 265 } 266 267 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 268 269 //===----------------------------------------------------------------------===// 270 // IndexToSizeOp 271 //===----------------------------------------------------------------------===// 272 273 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 274 // Constant values of both types, `shape.size` and `index`, are represented as 275 // `IntegerAttr`s which makes constant folding simple. 276 if (Attribute arg = operands[0]) 277 return arg; 278 return {}; 279 } 280 281 LogicalResult IndexToSizeOp::inferReturnTypes( 282 MLIRContext *context, Optional<Location> location, ValueRange operands, 283 DictionaryAttr attributes, RegionRange regions, 284 SmallVectorImpl<Type> &inferredReturnTypes) { 285 inferredReturnTypes.push_back(SizeType::get(context)); 286 return success(); 287 } 288 289 //===----------------------------------------------------------------------===// 290 // FromExtentsOp 291 //===----------------------------------------------------------------------===// 292 293 LogicalResult FromExtentsOp::inferReturnTypes( 294 MLIRContext *context, Optional<Location> location, ValueRange operands, 295 DictionaryAttr attributes, RegionRange regions, 296 SmallVectorImpl<Type> &inferredReturnTypes) { 297 inferredReturnTypes.push_back(ShapeType::get(context)); 298 return success(); 299 } 300 301 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 302 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 303 return nullptr; 304 SmallVector<int64_t, 6> extents; 305 for (auto attr : operands) 306 extents.push_back(attr.cast<IntegerAttr>().getInt()); 307 Builder builder(getContext()); 308 return builder.getIndexTensorAttr(extents); 309 } 310 311 //===----------------------------------------------------------------------===// 312 // GetExtentOp 313 //===----------------------------------------------------------------------===// 314 315 LogicalResult 316 GetExtentOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 317 ValueRange operands, DictionaryAttr attributes, 318 RegionRange regions, 319 SmallVectorImpl<Type> &inferredReturnTypes) { 320 inferredReturnTypes.push_back(SizeType::get(context)); 321 return success(); 322 } 323 324 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 325 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 326 if (!elements) 327 return nullptr; 328 uint64_t dimToGet = dim().getLimitedValue(); 329 // TODO: Constant fold this to some kind of constant error. 330 if (dimToGet >= (uint64_t)elements.getNumElements()) 331 return nullptr; 332 return elements.getValue({dimToGet}); 333 } 334 335 //===----------------------------------------------------------------------===// 336 // NumElementsOp 337 //===----------------------------------------------------------------------===// 338 339 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 340 341 // Fold only when argument constant. 342 Attribute shape = operands[0]; 343 if (!shape) 344 return {}; 345 346 APInt product(64, 1); 347 for (auto value : shape.cast<DenseIntElementsAttr>()) 348 product *= value; 349 Builder builder(getContext()); 350 return builder.getIndexAttr(product.getLimitedValue()); 351 } 352 353 LogicalResult NumElementsOp::inferReturnTypes( 354 MLIRContext *context, Optional<Location> location, ValueRange operands, 355 DictionaryAttr attributes, RegionRange regions, 356 SmallVectorImpl<Type> &inferredReturnTypes) { 357 inferredReturnTypes.push_back(SizeType::get(context)); 358 return success(); 359 } 360 361 //===----------------------------------------------------------------------===// 362 // ShapeOfOp 363 //===----------------------------------------------------------------------===// 364 365 LogicalResult 366 ShapeOfOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 367 ValueRange operands, DictionaryAttr attributes, 368 RegionRange regions, 369 SmallVectorImpl<Type> &inferredReturnTypes) { 370 inferredReturnTypes.push_back(ShapeType::get(context)); 371 return success(); 372 } 373 374 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 375 auto type = getOperand().getType().dyn_cast<ShapedType>(); 376 if (!type || !type.hasStaticShape()) 377 return nullptr; 378 Builder builder(getContext()); 379 return builder.getIndexTensorAttr(type.getShape()); 380 } 381 382 //===----------------------------------------------------------------------===// 383 // SizeToIndexOp 384 //===----------------------------------------------------------------------===// 385 386 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 387 // Constant values of both types, `shape.size` and `index`, are represented as 388 // `IntegerAttr`s which makes constant folding simple. 389 if (Attribute arg = operands[0]) 390 return arg; 391 return {}; 392 } 393 394 LogicalResult SizeToIndexOp::inferReturnTypes( 395 MLIRContext *context, Optional<Location> location, ValueRange operands, 396 DictionaryAttr attributes, RegionRange regions, 397 SmallVectorImpl<Type> &inferredReturnTypes) { 398 inferredReturnTypes.push_back(IndexType::get(context)); 399 return success(); 400 } 401 402 //===----------------------------------------------------------------------===// 403 // SplitAtOp 404 //===----------------------------------------------------------------------===// 405 406 LogicalResult 407 SplitAtOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 408 ValueRange operands, DictionaryAttr attributes, 409 RegionRange regions, 410 SmallVectorImpl<Type> &inferredReturnTypes) { 411 auto shapeType = ShapeType::get(context); 412 inferredReturnTypes.push_back(shapeType); 413 inferredReturnTypes.push_back(shapeType); 414 return success(); 415 } 416 417 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 418 SmallVectorImpl<OpFoldResult> &results) { 419 if (!operands[0] || !operands[1]) 420 return failure(); 421 auto shapeVec = llvm::to_vector<6>( 422 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 423 auto shape = llvm::makeArrayRef(shapeVec); 424 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 425 // Verify that the split point is in the correct range. 426 // TODO: Constant fold to an "error". 427 int64_t rank = shape.size(); 428 if (!(-rank <= splitPoint && splitPoint <= rank)) 429 return failure(); 430 if (splitPoint < 0) 431 splitPoint += shape.size(); 432 Builder builder(operands[0].getContext()); 433 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 434 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 435 return success(); 436 } 437 438 //===----------------------------------------------------------------------===// 439 // ToExtentTensorOp 440 //===----------------------------------------------------------------------===// 441 442 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 443 if (!operands[0]) 444 return nullptr; 445 Builder builder(getContext()); 446 auto shape = llvm::to_vector<6>( 447 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 448 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 449 builder.getIndexType()); 450 return DenseIntElementsAttr::get(type, shape); 451 } 452 453 namespace mlir { 454 namespace shape { 455 456 #define GET_OP_CLASSES 457 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 458 459 } // namespace shape 460 } // namespace mlir 461