1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/Traits.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/StandardTypes.h" 16 #include "llvm/Support/raw_ostream.h" 17 18 using namespace mlir; 19 using namespace mlir::shape; 20 21 ShapeDialect::ShapeDialect(MLIRContext *context) 22 : Dialect(getDialectNamespace(), context) { 23 addOperations< 24 #define GET_OP_LIST 25 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 26 >(); 27 addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType>(); 28 // Allow unknown operations during prototyping and testing. As the dialect is 29 // still evolving it makes it simple to start with an unregistered ops and 30 // try different variants before actually defining the op. 31 allowUnknownOperations(); 32 } 33 34 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 35 Attribute value, Type type, 36 Location loc) { 37 if (auto shapeType = type.dyn_cast<ShapeType>()) { 38 return builder.create<ConstShapeOp>(loc, type, 39 value.cast<DenseIntElementsAttr>()); 40 } 41 if (auto sizeType = type.dyn_cast<SizeType>()) { 42 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 43 } 44 return nullptr; 45 } 46 47 /// Parse a type registered to this dialect. 48 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 49 StringRef keyword; 50 if (parser.parseKeyword(&keyword)) 51 return Type(); 52 53 if (keyword == "component") 54 return ComponentType::get(getContext()); 55 if (keyword == "element") 56 return ElementType::get(getContext()); 57 if (keyword == "shape") 58 return ShapeType::get(getContext()); 59 if (keyword == "size") 60 return SizeType::get(getContext()); 61 if (keyword == "value_shape") 62 return ValueShapeType::get(getContext()); 63 64 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 65 return Type(); 66 } 67 68 /// Print a type registered to this dialect. 69 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 70 switch (type.getKind()) { 71 case ShapeTypes::Component: 72 os << "component"; 73 return; 74 case ShapeTypes::Element: 75 os << "element"; 76 return; 77 case ShapeTypes::Size: 78 os << "size"; 79 return; 80 case ShapeTypes::Shape: 81 os << "shape"; 82 return; 83 case ShapeTypes::ValueShape: 84 os << "value_shape"; 85 return; 86 default: 87 llvm_unreachable("unexpected 'shape' type kind"); 88 } 89 } 90 91 //===----------------------------------------------------------------------===// 92 // BroadcastOp 93 //===----------------------------------------------------------------------===// 94 95 LogicalResult 96 BroadcastOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 97 ValueRange operands, DictionaryAttr attributes, 98 RegionRange regions, 99 SmallVectorImpl<Type> &inferredReturnTypes) { 100 inferredReturnTypes.push_back(ShapeType::get(context)); 101 return success(); 102 } 103 104 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 105 if (!operands[0] || !operands[1]) 106 return nullptr; 107 auto lhsShape = llvm::to_vector<6>( 108 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 109 auto rhsShape = llvm::to_vector<6>( 110 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 111 SmallVector<int64_t, 6> resultShape; 112 // If the shapes are not compatible, we can't fold it. 113 // TODO: Fold to an "error". 114 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 115 return nullptr; 116 Builder builder(getContext()); 117 return builder.getI64TensorAttr(resultShape); 118 } 119 120 //===----------------------------------------------------------------------===// 121 // ConstShapeOp 122 //===----------------------------------------------------------------------===// 123 124 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 125 p << "shape.const_shape "; 126 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); 127 p << "["; 128 interleaveComma(op.shape().getValues<int64_t>(), p, 129 [&](int64_t i) { p << i; }); 130 p << "]"; 131 } 132 133 static ParseResult parseConstShapeOp(OpAsmParser &parser, 134 OperationState &result) { 135 if (parser.parseOptionalAttrDict(result.attributes)) 136 return failure(); 137 // We piggy-back on ArrayAttr parsing, though we don't internally store the 138 // shape as an ArrayAttr. 139 // TODO: Implement custom parser and maybe make syntax a bit more concise. 140 Attribute extentsRaw; 141 NamedAttrList dummy; 142 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 143 return failure(); 144 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 145 if (!extentsArray) 146 return failure(); 147 SmallVector<int64_t, 6> ints; 148 for (Attribute extent : extentsArray) { 149 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 150 if (!attr) 151 return failure(); 152 ints.push_back(attr.getInt()); 153 } 154 Builder &builder = parser.getBuilder(); 155 result.addAttribute("shape", builder.getI64TensorAttr(ints)); 156 157 result.types.push_back(ShapeType::get(builder.getContext())); 158 return success(); 159 } 160 161 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); } 162 163 LogicalResult 164 ConstShapeOp::inferReturnTypes(MLIRContext *context, 165 Optional<Location> location, ValueRange operands, 166 DictionaryAttr attributes, RegionRange regions, 167 SmallVectorImpl<Type> &inferredReturnTypes) { 168 inferredReturnTypes.push_back(ShapeType::get(context)); 169 return success(); 170 } 171 172 //===----------------------------------------------------------------------===// 173 // ConstSizeOp 174 //===----------------------------------------------------------------------===// 175 176 LogicalResult 177 ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 178 ValueRange operands, DictionaryAttr attributes, 179 RegionRange regions, 180 SmallVectorImpl<Type> &inferredReturnTypes) { 181 inferredReturnTypes.push_back(SizeType::get(context)); 182 return success(); 183 } 184 185 //===----------------------------------------------------------------------===// 186 // ShapeOfOp 187 //===----------------------------------------------------------------------===// 188 189 LogicalResult 190 ShapeOfOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 191 ValueRange operands, DictionaryAttr attributes, 192 RegionRange regions, 193 SmallVectorImpl<Type> &inferredReturnTypes) { 194 inferredReturnTypes.push_back(ShapeType::get(context)); 195 return success(); 196 } 197 198 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 199 auto type = getOperand().getType().dyn_cast<ShapedType>(); 200 if (!type || !type.hasStaticShape()) 201 return nullptr; 202 Builder builder(getContext()); 203 return builder.getI64TensorAttr(type.getShape()); 204 } 205 206 //===----------------------------------------------------------------------===// 207 // SplitAtOp 208 //===----------------------------------------------------------------------===// 209 210 LogicalResult 211 SplitAtOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 212 ValueRange operands, DictionaryAttr attributes, 213 RegionRange regions, 214 SmallVectorImpl<Type> &inferredReturnTypes) { 215 auto shapeType = ShapeType::get(context); 216 inferredReturnTypes.push_back(shapeType); 217 inferredReturnTypes.push_back(shapeType); 218 return success(); 219 } 220 221 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 222 SmallVectorImpl<OpFoldResult> &results) { 223 if (!operands[0] || !operands[1]) 224 return failure(); 225 auto shapeVec = llvm::to_vector<6>( 226 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 227 auto shape = llvm::makeArrayRef(shapeVec); 228 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 229 // Verify that the split point is in the correct range. 230 // TODO: Constant fold to an "error". 231 int64_t rank = shape.size(); 232 if (!(-rank <= splitPoint && splitPoint <= rank)) 233 return failure(); 234 if (splitPoint < 0) 235 splitPoint += shape.size(); 236 Builder builder(operands[0].getContext()); 237 results.push_back(builder.getI64TensorAttr(shape.take_front(splitPoint))); 238 results.push_back(builder.getI64TensorAttr(shape.drop_front(splitPoint))); 239 return success(); 240 } 241 242 //===----------------------------------------------------------------------===// 243 // ConcatOp 244 //===----------------------------------------------------------------------===// 245 246 LogicalResult 247 ConcatOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, 248 ValueRange operands, DictionaryAttr attributes, 249 RegionRange regions, 250 SmallVectorImpl<Type> &inferredReturnTypes) { 251 auto shapeType = ShapeType::get(context); 252 inferredReturnTypes.push_back(shapeType); 253 return success(); 254 } 255 256 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 257 if (!operands[0] || !operands[1]) 258 return nullptr; 259 auto lhsShape = llvm::to_vector<6>( 260 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 261 auto rhsShape = llvm::to_vector<6>( 262 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 263 SmallVector<int64_t, 6> resultShape; 264 resultShape.append(lhsShape.begin(), lhsShape.end()); 265 resultShape.append(rhsShape.begin(), rhsShape.end()); 266 Builder builder(getContext()); 267 return builder.getI64TensorAttr(resultShape); 268 } 269 270 //===----------------------------------------------------------------------===// 271 // ToExtentTensorOp 272 //===----------------------------------------------------------------------===// 273 274 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 275 if (!operands[0]) 276 return nullptr; 277 Builder builder(getContext()); 278 auto shape = llvm::to_vector<6>( 279 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 280 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 281 builder.getIndexType()); 282 return DenseIntElementsAttr::get(type, shape); 283 } 284 285 namespace mlir { 286 namespace shape { 287 288 #define GET_OP_CLASSES 289 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 290 291 } // namespace shape 292 } // namespace mlir 293