//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
//
//===----------------------------------------------------------------------===//
137ce1e7abSRiver Riddle 
147ce1e7abSRiver Riddle #include "mlir/Interfaces/InferTypeOpInterface.h"
1509f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
1609349303SJacques Pienaar #include "mlir/IR/Matchers.h"
1709349303SJacques Pienaar #include "llvm/Support/FormatVariadic.h"
187ce1e7abSRiver Riddle 
197ce1e7abSRiver Riddle using namespace mlir;
207ce1e7abSRiver Riddle 
217ce1e7abSRiver Riddle namespace mlir {
227ce1e7abSRiver Riddle #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
237ce1e7abSRiver Riddle } // namespace mlir
247ce1e7abSRiver Riddle 
hasRank() const2509349303SJacques Pienaar bool ShapeAdaptor::hasRank() const {
2609349303SJacques Pienaar   if (val.isNull())
2709349303SJacques Pienaar     return false;
2809349303SJacques Pienaar   if (auto t = val.dyn_cast<Type>())
2909349303SJacques Pienaar     return t.cast<ShapedType>().hasRank();
3009349303SJacques Pienaar   if (val.is<Attribute>())
3109349303SJacques Pienaar     return true;
3209349303SJacques Pienaar   return val.get<ShapedTypeComponents *>()->hasRank();
3309349303SJacques Pienaar }
3409349303SJacques Pienaar 
getElementType() const3509349303SJacques Pienaar Type ShapeAdaptor::getElementType() const {
3609349303SJacques Pienaar   if (val.isNull())
3709349303SJacques Pienaar     return nullptr;
3809349303SJacques Pienaar   if (auto t = val.dyn_cast<Type>())
3909349303SJacques Pienaar     return t.cast<ShapedType>().getElementType();
4009349303SJacques Pienaar   if (val.is<Attribute>())
4109349303SJacques Pienaar     return nullptr;
4209349303SJacques Pienaar   return val.get<ShapedTypeComponents *>()->getElementType();
4309349303SJacques Pienaar }
4409349303SJacques Pienaar 
getDims(SmallVectorImpl<int64_t> & res) const4509349303SJacques Pienaar void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
4609349303SJacques Pienaar   assert(hasRank());
4709349303SJacques Pienaar   if (auto t = val.dyn_cast<Type>()) {
4809349303SJacques Pienaar     ArrayRef<int64_t> vals = t.cast<ShapedType>().getShape();
4909349303SJacques Pienaar     res.assign(vals.begin(), vals.end());
5009349303SJacques Pienaar   } else if (auto attr = val.dyn_cast<Attribute>()) {
5109349303SJacques Pienaar     auto dattr = attr.cast<DenseIntElementsAttr>();
5209349303SJacques Pienaar     res.clear();
5309349303SJacques Pienaar     res.reserve(dattr.size());
540cb5d7fcSRiver Riddle     for (auto it : dattr.getValues<APInt>())
5509349303SJacques Pienaar       res.push_back(it.getSExtValue());
5609349303SJacques Pienaar   } else {
5709349303SJacques Pienaar     auto vals = val.get<ShapedTypeComponents *>()->getDims();
5809349303SJacques Pienaar     res.assign(vals.begin(), vals.end());
5909349303SJacques Pienaar   }
6009349303SJacques Pienaar }
6109349303SJacques Pienaar 
getDims(ShapedTypeComponents & res) const6209349303SJacques Pienaar void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
6309349303SJacques Pienaar   assert(hasRank());
6409349303SJacques Pienaar   res.ranked = true;
6509349303SJacques Pienaar   getDims(res.dims);
6609349303SJacques Pienaar }
6709349303SJacques Pienaar 
getDimSize(int index) const6809349303SJacques Pienaar int64_t ShapeAdaptor::getDimSize(int index) const {
6909349303SJacques Pienaar   assert(hasRank());
7009349303SJacques Pienaar   if (auto t = val.dyn_cast<Type>())
7109349303SJacques Pienaar     return t.cast<ShapedType>().getDimSize(index);
7209349303SJacques Pienaar   if (auto attr = val.dyn_cast<Attribute>())
7309349303SJacques Pienaar     return attr.cast<DenseIntElementsAttr>()
74ae40d625SRiver Riddle         .getValues<APInt>()[index]
7509349303SJacques Pienaar         .getSExtValue();
7609349303SJacques Pienaar   auto *stc = val.get<ShapedTypeComponents *>();
7709349303SJacques Pienaar   return stc->getDims()[index];
7809349303SJacques Pienaar }
7909349303SJacques Pienaar 
getRank() const8009349303SJacques Pienaar int64_t ShapeAdaptor::getRank() const {
8109349303SJacques Pienaar   assert(hasRank());
8209349303SJacques Pienaar   if (auto t = val.dyn_cast<Type>())
8309349303SJacques Pienaar     return t.cast<ShapedType>().getRank();
8409349303SJacques Pienaar   if (auto attr = val.dyn_cast<Attribute>())
8509349303SJacques Pienaar     return attr.cast<DenseIntElementsAttr>().size();
8609349303SJacques Pienaar   return val.get<ShapedTypeComponents *>()->getDims().size();
8709349303SJacques Pienaar }
8809349303SJacques Pienaar 
hasStaticShape() const8909349303SJacques Pienaar bool ShapeAdaptor::hasStaticShape() const {
9009349303SJacques Pienaar   if (!hasRank())
9109349303SJacques Pienaar     return false;
9209349303SJacques Pienaar 
9309349303SJacques Pienaar   if (auto t = val.dyn_cast<Type>())
9409349303SJacques Pienaar     return t.cast<ShapedType>().hasStaticShape();
9509349303SJacques Pienaar   if (auto attr = val.dyn_cast<Attribute>()) {
9609349303SJacques Pienaar     auto dattr = attr.cast<DenseIntElementsAttr>();
970cb5d7fcSRiver Riddle     for (auto index : dattr.getValues<APInt>())
9809349303SJacques Pienaar       if (ShapedType::isDynamic(index.getSExtValue()))
9909349303SJacques Pienaar         return false;
10009349303SJacques Pienaar     return true;
10109349303SJacques Pienaar   }
10209349303SJacques Pienaar   auto *stc = val.get<ShapedTypeComponents *>();
103*380a1b20SKazu Hirata   return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
10409349303SJacques Pienaar }
10509349303SJacques Pienaar 
getNumElements() const10609349303SJacques Pienaar int64_t ShapeAdaptor::getNumElements() const {
10709349303SJacques Pienaar   assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
10809349303SJacques Pienaar 
10909349303SJacques Pienaar   if (auto t = val.dyn_cast<Type>())
11009349303SJacques Pienaar     return t.cast<ShapedType>().getNumElements();
11109349303SJacques Pienaar 
11209349303SJacques Pienaar   if (auto attr = val.dyn_cast<Attribute>()) {
11309349303SJacques Pienaar     auto dattr = attr.cast<DenseIntElementsAttr>();
11409349303SJacques Pienaar     int64_t num = 1;
1150cb5d7fcSRiver Riddle     for (auto index : dattr.getValues<APInt>()) {
11609349303SJacques Pienaar       num *= index.getZExtValue();
11709349303SJacques Pienaar       assert(num >= 0 && "integer overflow in element count computation");
11809349303SJacques Pienaar     }
11909349303SJacques Pienaar     return num;
12009349303SJacques Pienaar   }
12109349303SJacques Pienaar 
12209349303SJacques Pienaar   auto *stc = val.get<ShapedTypeComponents *>();
12309349303SJacques Pienaar   int64_t num = 1;
12409349303SJacques Pienaar   for (int64_t dim : stc->getDims()) {
12509349303SJacques Pienaar     num *= dim;
12609349303SJacques Pienaar     assert(num >= 0 && "integer overflow in element count computation");
12709349303SJacques Pienaar   }
12809349303SJacques Pienaar   return num;
12909349303SJacques Pienaar }
13009349303SJacques Pienaar 
dump() const13109349303SJacques Pienaar void ShapeAdaptor::dump() const {
13209349303SJacques Pienaar   if (!hasRank()) {
13309349303SJacques Pienaar     llvm::errs() << "<<unranked>>\n";
13409349303SJacques Pienaar     return;
13509349303SJacques Pienaar   }
13609349303SJacques Pienaar 
13709349303SJacques Pienaar   SmallVector<int64_t> dims;
13809349303SJacques Pienaar   getDims(dims);
13909349303SJacques Pienaar   auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
14009349303SJacques Pienaar     if (ShapedType::isDynamic(dim))
14109349303SJacques Pienaar       return "?";
14209349303SJacques Pienaar     return llvm::formatv("{0}", dim).str();
14309349303SJacques Pienaar   });
14409349303SJacques Pienaar   llvm::errs() << "rank = " << getRank() << " dims = [";
14509349303SJacques Pienaar   llvm::interleave(mapped, llvm::errs(), "x");
14609349303SJacques Pienaar   llvm::errs() << "]\n";
14709349303SJacques Pienaar }
14809349303SJacques Pienaar 
getValueAsShape(int index)14909349303SJacques Pienaar ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
15009349303SJacques Pienaar   Value val = operator[](index);
15109349303SJacques Pienaar   if (valueToShape)
15209349303SJacques Pienaar     if (ShapeAdaptor ret = valueToShape(val))
15309349303SJacques Pienaar       return ret;
15409349303SJacques Pienaar 
15509349303SJacques Pienaar   DenseIntElementsAttr attr;
15609349303SJacques Pienaar   if (!matchPattern(val, m_Constant(&attr)))
15709349303SJacques Pienaar     return nullptr;
15809349303SJacques Pienaar   if (attr.getType().getRank() != 1)
15909349303SJacques Pienaar     return nullptr;
16009349303SJacques Pienaar   return attr;
16109349303SJacques Pienaar }
16209349303SJacques Pienaar 
getShape(Value val) const16309349303SJacques Pienaar ShapeAdaptor ValueShapeRange::getShape(Value val) const {
16409349303SJacques Pienaar   if (operandShape)
16509349303SJacques Pienaar     if (ShapeAdaptor ret = operandShape(val))
16609349303SJacques Pienaar       return ret;
16709349303SJacques Pienaar   return val.getType();
16809349303SJacques Pienaar }
16909349303SJacques Pienaar 
getShape(int index) const17009349303SJacques Pienaar ShapeAdaptor ValueShapeRange::getShape(int index) const {
17109349303SJacques Pienaar   if (index < 0 || static_cast<size_t>(index) >= size())
17209349303SJacques Pienaar     return nullptr;
17309349303SJacques Pienaar   return getShape(operator[](index));
17409349303SJacques Pienaar }
17509349303SJacques Pienaar 
inferReturnTensorTypes(function_ref<LogicalResult (MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & retComponents)> componentTypeFn,MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1767ce1e7abSRiver Riddle LogicalResult mlir::detail::inferReturnTensorTypes(
1777ce1e7abSRiver Riddle     function_ref<LogicalResult(
178a5b889dbSJacques Pienaar         MLIRContext *, Optional<Location> location, ValueShapeRange operands,
1795eae715aSJacques Pienaar         DictionaryAttr attributes, RegionRange regions,
1807ce1e7abSRiver Riddle         SmallVectorImpl<ShapedTypeComponents> &retComponents)>
1817ce1e7abSRiver Riddle         componentTypeFn,
1827ce1e7abSRiver Riddle     MLIRContext *context, Optional<Location> location, ValueRange operands,
1835eae715aSJacques Pienaar     DictionaryAttr attributes, RegionRange regions,
1847ce1e7abSRiver Riddle     SmallVectorImpl<Type> &inferredReturnTypes) {
1857ce1e7abSRiver Riddle   SmallVector<ShapedTypeComponents, 2> retComponents;
1867ce1e7abSRiver Riddle   if (failed(componentTypeFn(context, location, operands, attributes, regions,
1877ce1e7abSRiver Riddle                              retComponents)))
1887ce1e7abSRiver Riddle     return failure();
189e4853be2SMehdi Amini   for (const auto &shapeAndType : retComponents) {
1907ce1e7abSRiver Riddle     assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
19135bd4191SJacques Pienaar     assert(shapeAndType.getElementType() &&
19235bd4191SJacques Pienaar            "element type required to construct tensor");
1937ce1e7abSRiver Riddle     if (shapeAndType.hasRank())
1947ce1e7abSRiver Riddle       inferredReturnTypes.push_back(RankedTensorType::get(
1957ce1e7abSRiver Riddle           shapeAndType.getDims(), shapeAndType.getElementType()));
1967ce1e7abSRiver Riddle     else
1977ce1e7abSRiver Riddle       inferredReturnTypes.push_back(
1987ce1e7abSRiver Riddle           UnrankedTensorType::get(shapeAndType.getElementType()));
1997ce1e7abSRiver Riddle   }
2007ce1e7abSRiver Riddle   return success();
2017ce1e7abSRiver Riddle }
2027ce1e7abSRiver Riddle 
verifyInferredResultTypes(Operation * op)2037ce1e7abSRiver Riddle LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
204c8598fa2SJacques Pienaar   SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
2057ce1e7abSRiver Riddle   auto retTypeFn = cast<InferTypeOpInterface>(op);
206c8598fa2SJacques Pienaar   return retTypeFn.refineReturnTypes(op->getContext(), op->getLoc(),
207c8598fa2SJacques Pienaar                                      op->getOperands(), op->getAttrDictionary(),
208c8598fa2SJacques Pienaar                                      op->getRegions(), inferredReturnTypes);
2097ce1e7abSRiver Riddle }
210