//===- Shape.cpp - MLIR Shape Operations ----------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::shape; ShapeDialect::ShapeDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" >(); addTypes(); // Allow unknown operations during prototyping and testing. As the dialect is // still evolving it makes it simple to start with an unregistered ops and // try different variants before actually defining the op. allowUnknownOperations(); } Operation *ShapeDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto shapeType = type.dyn_cast()) { return builder.create(loc, type, value.cast()); } if (auto sizeType = type.dyn_cast()) { return builder.create(loc, type, value.cast()); } return nullptr; } /// Parse a type registered to this dialect. Type ShapeDialect::parseType(DialectAsmParser &parser) const { StringRef keyword; if (parser.parseKeyword(&keyword)) return Type(); if (keyword == "component") return ComponentType::get(getContext()); if (keyword == "element") return ElementType::get(getContext()); if (keyword == "shape") return ShapeType::get(getContext()); if (keyword == "size") return SizeType::get(getContext()); if (keyword == "value_shape") return ValueShapeType::get(getContext()); if (keyword == "witness") return WitnessType::get(getContext()); parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; return Type(); } /// Print a type registered to this dialect. void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { switch (type.getKind()) { case ShapeTypes::Component: os << "component"; return; case ShapeTypes::Element: os << "element"; return; case ShapeTypes::Size: os << "size"; return; case ShapeTypes::Shape: os << "shape"; return; case ShapeTypes::ValueShape: os << "value_shape"; return; case ShapeTypes::Witness: os << "witness"; return; default: llvm_unreachable("unexpected 'shape' type kind"); } } //===----------------------------------------------------------------------===// // AnyOp //===----------------------------------------------------------------------===// LogicalResult AnyOp::inferReturnTypes(MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(ShapeType::get(context)); return success(); } //===----------------------------------------------------------------------===// // AssumingOp //===----------------------------------------------------------------------===// static ParseResult parseAssumingOp(OpAsmParser &parser, OperationState &result) { result.regions.reserve(1); Region *doRegion = result.addRegion(); auto &builder = parser.getBuilder(); OpAsmParser::OperandType cond; if (parser.parseOperand(cond) || parser.resolveOperand(cond, builder.getType(), result.operands)) return failure(); // Parse optional results type list. if (parser.parseOptionalArrowTypeList(result.types)) return failure(); // Parse the region and add a terminator if elided. if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); // Parse the optional attribute list. if (parser.parseOptionalAttrDict(result.attributes)) return failure(); return success(); } static void print(OpAsmPrinter &p, AssumingOp op) { bool yieldsResults = !op.results().empty(); p << AssumingOp::getOperationName() << " " << op.witness(); if (yieldsResults) { p << " -> (" << op.getResultTypes() << ")"; } p.printRegion(op.doRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/yieldsResults); p.printOptionalAttrDict(op.getAttrs()); } //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// LogicalResult BroadcastOp::inferReturnTypes(MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(ShapeType::get(context)); return success(); } OpFoldResult BroadcastOp::fold(ArrayRef operands) { if (!operands[0] || !operands[1]) return nullptr; auto lhsShape = llvm::to_vector<6>( operands[0].cast().getValues()); auto rhsShape = llvm::to_vector<6>( operands[1].cast().getValues()); SmallVector resultShape; // If the shapes are not compatible, we can't fold it. // TODO: Fold to an "error". if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) return nullptr; Builder builder(getContext()); return builder.getIndexTensorAttr(resultShape); } //===----------------------------------------------------------------------===// // ConcatOp //===----------------------------------------------------------------------===// LogicalResult ConcatOp::inferReturnTypes(MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { auto shapeType = ShapeType::get(context); inferredReturnTypes.push_back(shapeType); return success(); } OpFoldResult ConcatOp::fold(ArrayRef operands) { if (!operands[0] || !operands[1]) return nullptr; auto lhsShape = llvm::to_vector<6>( operands[0].cast().getValues()); auto rhsShape = llvm::to_vector<6>( operands[1].cast().getValues()); SmallVector resultShape; resultShape.append(lhsShape.begin(), lhsShape.end()); resultShape.append(rhsShape.begin(), rhsShape.end()); Builder builder(getContext()); return builder.getIndexTensorAttr(resultShape); } //===----------------------------------------------------------------------===// // ConstShapeOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, ConstShapeOp &op) { p << "shape.const_shape "; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); p << "["; interleaveComma(op.shape().getValues(), p, [&](int64_t i) { p << i; }); p << "]"; } static ParseResult parseConstShapeOp(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); // We piggy-back on ArrayAttr parsing, though we don't internally store the // shape as an ArrayAttr. // TODO: Implement custom parser and maybe make syntax a bit more concise. Attribute extentsRaw; NamedAttrList dummy; if (parser.parseAttribute(extentsRaw, "dummy", dummy)) return failure(); auto extentsArray = extentsRaw.dyn_cast(); if (!extentsArray) return failure(); SmallVector ints; for (Attribute extent : extentsArray) { IntegerAttr attr = extent.dyn_cast(); if (!attr) return failure(); ints.push_back(attr.getInt()); } Builder &builder = parser.getBuilder(); result.addAttribute("shape", builder.getIndexTensorAttr(ints)); result.types.push_back(ShapeType::get(builder.getContext())); return success(); } OpFoldResult ConstShapeOp::fold(ArrayRef) { return shapeAttr(); } //===----------------------------------------------------------------------===// // ConstSizeOp //===----------------------------------------------------------------------===// LogicalResult ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(SizeType::get(context)); return success(); } OpFoldResult ConstSizeOp::fold(ArrayRef) { return valueAttr(); } //===----------------------------------------------------------------------===// // IndexToSizeOp //===----------------------------------------------------------------------===// OpFoldResult IndexToSizeOp::fold(ArrayRef operands) { // Constant values of both types, `shape.size` and `index`, are represented as // `IntegerAttr`s which makes constant folding simple. if (Attribute arg = operands[0]) return arg; return {}; } LogicalResult IndexToSizeOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(SizeType::get(context)); return success(); } //===----------------------------------------------------------------------===// // FromExtentsOp //===----------------------------------------------------------------------===// LogicalResult FromExtentsOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(ShapeType::get(context)); return success(); } OpFoldResult FromExtentsOp::fold(ArrayRef operands) { if (llvm::any_of(operands, [](Attribute a) { return !a; })) return nullptr; SmallVector extents; for (auto attr : operands) extents.push_back(attr.cast().getInt()); Builder builder(getContext()); return builder.getIndexTensorAttr(extents); } //===----------------------------------------------------------------------===// // GetExtentOp //===----------------------------------------------------------------------===// LogicalResult GetExtentOp::inferReturnTypes(MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(SizeType::get(context)); return success(); } OpFoldResult GetExtentOp::fold(ArrayRef operands) { auto elements = operands[0].dyn_cast_or_null(); if (!elements) return nullptr; uint64_t dimToGet = dim().getLimitedValue(); // TODO: Constant fold this to some kind of constant error. if (dimToGet >= (uint64_t)elements.getNumElements()) return nullptr; return elements.getValue({dimToGet}); } //===----------------------------------------------------------------------===// // NumElementsOp //===----------------------------------------------------------------------===// OpFoldResult NumElementsOp::fold(ArrayRef operands) { // Fold only when argument constant. Attribute shape = operands[0]; if (!shape) return {}; APInt product(64, 1); for (auto value : shape.cast()) product *= value; Builder builder(getContext()); return builder.getIndexAttr(product.getLimitedValue()); } LogicalResult NumElementsOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(SizeType::get(context)); return success(); } //===----------------------------------------------------------------------===// // ShapeOfOp //===----------------------------------------------------------------------===// LogicalResult ShapeOfOp::inferReturnTypes(MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(ShapeType::get(context)); return success(); } OpFoldResult ShapeOfOp::fold(ArrayRef) { auto type = getOperand().getType().dyn_cast(); if (!type || !type.hasStaticShape()) return nullptr; Builder builder(getContext()); return builder.getIndexTensorAttr(type.getShape()); } //===----------------------------------------------------------------------===// // SizeToIndexOp //===----------------------------------------------------------------------===// OpFoldResult SizeToIndexOp::fold(ArrayRef operands) { // Constant values of both types, `shape.size` and `index`, are represented as // `IntegerAttr`s which makes constant folding simple. if (Attribute arg = operands[0]) return arg; return {}; } LogicalResult SizeToIndexOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(IndexType::get(context)); return success(); } //===----------------------------------------------------------------------===// // SplitAtOp //===----------------------------------------------------------------------===// LogicalResult SplitAtOp::inferReturnTypes(MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { auto shapeType = ShapeType::get(context); inferredReturnTypes.push_back(shapeType); inferredReturnTypes.push_back(shapeType); return success(); } LogicalResult SplitAtOp::fold(ArrayRef operands, SmallVectorImpl &results) { if (!operands[0] || !operands[1]) return failure(); auto shapeVec = llvm::to_vector<6>( operands[0].cast().getValues()); auto shape = llvm::makeArrayRef(shapeVec); auto splitPoint = operands[1].cast().getInt(); // Verify that the split point is in the correct range. // TODO: Constant fold to an "error". int64_t rank = shape.size(); if (!(-rank <= splitPoint && splitPoint <= rank)) return failure(); if (splitPoint < 0) splitPoint += shape.size(); Builder builder(operands[0].getContext()); results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); return success(); } //===----------------------------------------------------------------------===// // ToExtentTensorOp //===----------------------------------------------------------------------===// OpFoldResult ToExtentTensorOp::fold(ArrayRef operands) { if (!operands[0]) return nullptr; Builder builder(getContext()); auto shape = llvm::to_vector<6>( operands[0].cast().getValues()); auto type = RankedTensorType::get({static_cast(shape.size())}, builder.getIndexType()); return DenseIntElementsAttr::get(type, shape); } namespace mlir { namespace shape { #define GET_OP_CLASSES #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" } // namespace shape } // namespace mlir