17ce1e7abSRiver Riddle //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
27ce1e7abSRiver Riddle //
37ce1e7abSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47ce1e7abSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
57ce1e7abSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67ce1e7abSRiver Riddle //
77ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
87ce1e7abSRiver Riddle //
97ce1e7abSRiver Riddle // This file contains the definitions of the infer op interfaces defined in
107ce1e7abSRiver Riddle // `InferTypeOpInterface.td`.
117ce1e7abSRiver Riddle //
127ce1e7abSRiver Riddle //===----------------------------------------------------------------------===//
137ce1e7abSRiver Riddle
147ce1e7abSRiver Riddle #include "mlir/Interfaces/InferTypeOpInterface.h"
1509f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
1609349303SJacques Pienaar #include "mlir/IR/Matchers.h"
1709349303SJacques Pienaar #include "llvm/Support/FormatVariadic.h"
187ce1e7abSRiver Riddle
197ce1e7abSRiver Riddle using namespace mlir;
207ce1e7abSRiver Riddle
217ce1e7abSRiver Riddle namespace mlir {
227ce1e7abSRiver Riddle #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
237ce1e7abSRiver Riddle } // namespace mlir
247ce1e7abSRiver Riddle
hasRank() const2509349303SJacques Pienaar bool ShapeAdaptor::hasRank() const {
2609349303SJacques Pienaar if (val.isNull())
2709349303SJacques Pienaar return false;
2809349303SJacques Pienaar if (auto t = val.dyn_cast<Type>())
2909349303SJacques Pienaar return t.cast<ShapedType>().hasRank();
3009349303SJacques Pienaar if (val.is<Attribute>())
3109349303SJacques Pienaar return true;
3209349303SJacques Pienaar return val.get<ShapedTypeComponents *>()->hasRank();
3309349303SJacques Pienaar }
3409349303SJacques Pienaar
getElementType() const3509349303SJacques Pienaar Type ShapeAdaptor::getElementType() const {
3609349303SJacques Pienaar if (val.isNull())
3709349303SJacques Pienaar return nullptr;
3809349303SJacques Pienaar if (auto t = val.dyn_cast<Type>())
3909349303SJacques Pienaar return t.cast<ShapedType>().getElementType();
4009349303SJacques Pienaar if (val.is<Attribute>())
4109349303SJacques Pienaar return nullptr;
4209349303SJacques Pienaar return val.get<ShapedTypeComponents *>()->getElementType();
4309349303SJacques Pienaar }
4409349303SJacques Pienaar
getDims(SmallVectorImpl<int64_t> & res) const4509349303SJacques Pienaar void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
4609349303SJacques Pienaar assert(hasRank());
4709349303SJacques Pienaar if (auto t = val.dyn_cast<Type>()) {
4809349303SJacques Pienaar ArrayRef<int64_t> vals = t.cast<ShapedType>().getShape();
4909349303SJacques Pienaar res.assign(vals.begin(), vals.end());
5009349303SJacques Pienaar } else if (auto attr = val.dyn_cast<Attribute>()) {
5109349303SJacques Pienaar auto dattr = attr.cast<DenseIntElementsAttr>();
5209349303SJacques Pienaar res.clear();
5309349303SJacques Pienaar res.reserve(dattr.size());
540cb5d7fcSRiver Riddle for (auto it : dattr.getValues<APInt>())
5509349303SJacques Pienaar res.push_back(it.getSExtValue());
5609349303SJacques Pienaar } else {
5709349303SJacques Pienaar auto vals = val.get<ShapedTypeComponents *>()->getDims();
5809349303SJacques Pienaar res.assign(vals.begin(), vals.end());
5909349303SJacques Pienaar }
6009349303SJacques Pienaar }
6109349303SJacques Pienaar
getDims(ShapedTypeComponents & res) const6209349303SJacques Pienaar void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
6309349303SJacques Pienaar assert(hasRank());
6409349303SJacques Pienaar res.ranked = true;
6509349303SJacques Pienaar getDims(res.dims);
6609349303SJacques Pienaar }
6709349303SJacques Pienaar
getDimSize(int index) const6809349303SJacques Pienaar int64_t ShapeAdaptor::getDimSize(int index) const {
6909349303SJacques Pienaar assert(hasRank());
7009349303SJacques Pienaar if (auto t = val.dyn_cast<Type>())
7109349303SJacques Pienaar return t.cast<ShapedType>().getDimSize(index);
7209349303SJacques Pienaar if (auto attr = val.dyn_cast<Attribute>())
7309349303SJacques Pienaar return attr.cast<DenseIntElementsAttr>()
74ae40d625SRiver Riddle .getValues<APInt>()[index]
7509349303SJacques Pienaar .getSExtValue();
7609349303SJacques Pienaar auto *stc = val.get<ShapedTypeComponents *>();
7709349303SJacques Pienaar return stc->getDims()[index];
7809349303SJacques Pienaar }
7909349303SJacques Pienaar
getRank() const8009349303SJacques Pienaar int64_t ShapeAdaptor::getRank() const {
8109349303SJacques Pienaar assert(hasRank());
8209349303SJacques Pienaar if (auto t = val.dyn_cast<Type>())
8309349303SJacques Pienaar return t.cast<ShapedType>().getRank();
8409349303SJacques Pienaar if (auto attr = val.dyn_cast<Attribute>())
8509349303SJacques Pienaar return attr.cast<DenseIntElementsAttr>().size();
8609349303SJacques Pienaar return val.get<ShapedTypeComponents *>()->getDims().size();
8709349303SJacques Pienaar }
8809349303SJacques Pienaar
hasStaticShape() const8909349303SJacques Pienaar bool ShapeAdaptor::hasStaticShape() const {
9009349303SJacques Pienaar if (!hasRank())
9109349303SJacques Pienaar return false;
9209349303SJacques Pienaar
9309349303SJacques Pienaar if (auto t = val.dyn_cast<Type>())
9409349303SJacques Pienaar return t.cast<ShapedType>().hasStaticShape();
9509349303SJacques Pienaar if (auto attr = val.dyn_cast<Attribute>()) {
9609349303SJacques Pienaar auto dattr = attr.cast<DenseIntElementsAttr>();
970cb5d7fcSRiver Riddle for (auto index : dattr.getValues<APInt>())
9809349303SJacques Pienaar if (ShapedType::isDynamic(index.getSExtValue()))
9909349303SJacques Pienaar return false;
10009349303SJacques Pienaar return true;
10109349303SJacques Pienaar }
10209349303SJacques Pienaar auto *stc = val.get<ShapedTypeComponents *>();
103*380a1b20SKazu Hirata return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
10409349303SJacques Pienaar }
10509349303SJacques Pienaar
getNumElements() const10609349303SJacques Pienaar int64_t ShapeAdaptor::getNumElements() const {
10709349303SJacques Pienaar assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
10809349303SJacques Pienaar
10909349303SJacques Pienaar if (auto t = val.dyn_cast<Type>())
11009349303SJacques Pienaar return t.cast<ShapedType>().getNumElements();
11109349303SJacques Pienaar
11209349303SJacques Pienaar if (auto attr = val.dyn_cast<Attribute>()) {
11309349303SJacques Pienaar auto dattr = attr.cast<DenseIntElementsAttr>();
11409349303SJacques Pienaar int64_t num = 1;
1150cb5d7fcSRiver Riddle for (auto index : dattr.getValues<APInt>()) {
11609349303SJacques Pienaar num *= index.getZExtValue();
11709349303SJacques Pienaar assert(num >= 0 && "integer overflow in element count computation");
11809349303SJacques Pienaar }
11909349303SJacques Pienaar return num;
12009349303SJacques Pienaar }
12109349303SJacques Pienaar
12209349303SJacques Pienaar auto *stc = val.get<ShapedTypeComponents *>();
12309349303SJacques Pienaar int64_t num = 1;
12409349303SJacques Pienaar for (int64_t dim : stc->getDims()) {
12509349303SJacques Pienaar num *= dim;
12609349303SJacques Pienaar assert(num >= 0 && "integer overflow in element count computation");
12709349303SJacques Pienaar }
12809349303SJacques Pienaar return num;
12909349303SJacques Pienaar }
13009349303SJacques Pienaar
dump() const13109349303SJacques Pienaar void ShapeAdaptor::dump() const {
13209349303SJacques Pienaar if (!hasRank()) {
13309349303SJacques Pienaar llvm::errs() << "<<unranked>>\n";
13409349303SJacques Pienaar return;
13509349303SJacques Pienaar }
13609349303SJacques Pienaar
13709349303SJacques Pienaar SmallVector<int64_t> dims;
13809349303SJacques Pienaar getDims(dims);
13909349303SJacques Pienaar auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
14009349303SJacques Pienaar if (ShapedType::isDynamic(dim))
14109349303SJacques Pienaar return "?";
14209349303SJacques Pienaar return llvm::formatv("{0}", dim).str();
14309349303SJacques Pienaar });
14409349303SJacques Pienaar llvm::errs() << "rank = " << getRank() << " dims = [";
14509349303SJacques Pienaar llvm::interleave(mapped, llvm::errs(), "x");
14609349303SJacques Pienaar llvm::errs() << "]\n";
14709349303SJacques Pienaar }
14809349303SJacques Pienaar
getValueAsShape(int index)14909349303SJacques Pienaar ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
15009349303SJacques Pienaar Value val = operator[](index);
15109349303SJacques Pienaar if (valueToShape)
15209349303SJacques Pienaar if (ShapeAdaptor ret = valueToShape(val))
15309349303SJacques Pienaar return ret;
15409349303SJacques Pienaar
15509349303SJacques Pienaar DenseIntElementsAttr attr;
15609349303SJacques Pienaar if (!matchPattern(val, m_Constant(&attr)))
15709349303SJacques Pienaar return nullptr;
15809349303SJacques Pienaar if (attr.getType().getRank() != 1)
15909349303SJacques Pienaar return nullptr;
16009349303SJacques Pienaar return attr;
16109349303SJacques Pienaar }
16209349303SJacques Pienaar
getShape(Value val) const16309349303SJacques Pienaar ShapeAdaptor ValueShapeRange::getShape(Value val) const {
16409349303SJacques Pienaar if (operandShape)
16509349303SJacques Pienaar if (ShapeAdaptor ret = operandShape(val))
16609349303SJacques Pienaar return ret;
16709349303SJacques Pienaar return val.getType();
16809349303SJacques Pienaar }
16909349303SJacques Pienaar
getShape(int index) const17009349303SJacques Pienaar ShapeAdaptor ValueShapeRange::getShape(int index) const {
17109349303SJacques Pienaar if (index < 0 || static_cast<size_t>(index) >= size())
17209349303SJacques Pienaar return nullptr;
17309349303SJacques Pienaar return getShape(operator[](index));
17409349303SJacques Pienaar }
17509349303SJacques Pienaar
inferReturnTensorTypes(function_ref<LogicalResult (MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & retComponents)> componentTypeFn,MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1767ce1e7abSRiver Riddle LogicalResult mlir::detail::inferReturnTensorTypes(
1777ce1e7abSRiver Riddle function_ref<LogicalResult(
178a5b889dbSJacques Pienaar MLIRContext *, Optional<Location> location, ValueShapeRange operands,
1795eae715aSJacques Pienaar DictionaryAttr attributes, RegionRange regions,
1807ce1e7abSRiver Riddle SmallVectorImpl<ShapedTypeComponents> &retComponents)>
1817ce1e7abSRiver Riddle componentTypeFn,
1827ce1e7abSRiver Riddle MLIRContext *context, Optional<Location> location, ValueRange operands,
1835eae715aSJacques Pienaar DictionaryAttr attributes, RegionRange regions,
1847ce1e7abSRiver Riddle SmallVectorImpl<Type> &inferredReturnTypes) {
1857ce1e7abSRiver Riddle SmallVector<ShapedTypeComponents, 2> retComponents;
1867ce1e7abSRiver Riddle if (failed(componentTypeFn(context, location, operands, attributes, regions,
1877ce1e7abSRiver Riddle retComponents)))
1887ce1e7abSRiver Riddle return failure();
189e4853be2SMehdi Amini for (const auto &shapeAndType : retComponents) {
1907ce1e7abSRiver Riddle assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
19135bd4191SJacques Pienaar assert(shapeAndType.getElementType() &&
19235bd4191SJacques Pienaar "element type required to construct tensor");
1937ce1e7abSRiver Riddle if (shapeAndType.hasRank())
1947ce1e7abSRiver Riddle inferredReturnTypes.push_back(RankedTensorType::get(
1957ce1e7abSRiver Riddle shapeAndType.getDims(), shapeAndType.getElementType()));
1967ce1e7abSRiver Riddle else
1977ce1e7abSRiver Riddle inferredReturnTypes.push_back(
1987ce1e7abSRiver Riddle UnrankedTensorType::get(shapeAndType.getElementType()));
1997ce1e7abSRiver Riddle }
2007ce1e7abSRiver Riddle return success();
2017ce1e7abSRiver Riddle }
2027ce1e7abSRiver Riddle
verifyInferredResultTypes(Operation * op)2037ce1e7abSRiver Riddle LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
204c8598fa2SJacques Pienaar SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
2057ce1e7abSRiver Riddle auto retTypeFn = cast<InferTypeOpInterface>(op);
206c8598fa2SJacques Pienaar return retTypeFn.refineReturnTypes(op->getContext(), op->getLoc(),
207c8598fa2SJacques Pienaar op->getOperands(), op->getAttrDictionary(),
208c8598fa2SJacques Pienaar op->getRegions(), inferredReturnTypes);
2097ce1e7abSRiver Riddle }
210