//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
//
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/InferTypeOpInterface.h"

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

namespace mlir {
#include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
} // namespace mlir

// A ShapeAdaptor wraps one of three shape sources: a shaped Type, a
// DenseIntElementsAttr describing the shape, or a ShapedTypeComponents.
bool ShapeAdaptor::hasRank() const {
  if (val.isNull())
    return false;
  if (auto t = val.dyn_cast<Type>())
    return t.cast<ShapedType>().hasRank();
  if (val.is<Attribute>())
    return true;
  return val.get<ShapedTypeComponents *>()->hasRank();
}

Type ShapeAdaptor::getElementType() const {
  if (val.isNull())
    return nullptr;
  if (auto t = val.dyn_cast<Type>())
    return t.cast<ShapedType>().getElementType();
  if (val.is<Attribute>())
    return nullptr;
  return val.get<ShapedTypeComponents *>()->getElementType();
}

void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
  assert(hasRank());
  if (auto t = val.dyn_cast<Type>()) {
    ArrayRef<int64_t> vals = t.cast<ShapedType>().getShape();
    res.assign(vals.begin(), vals.end());
  } else if (auto attr = val.dyn_cast<Attribute>()) {
    auto dattr = attr.cast<DenseIntElementsAttr>();
    res.clear();
    res.reserve(dattr.size());
    for (auto it : dattr.getValues<APInt>())
      res.push_back(it.getSExtValue());
  } else {
    auto vals = val.get<ShapedTypeComponents *>()->getDims();
    res.assign(vals.begin(), vals.end());
  }
}

void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
  assert(hasRank());
  res.ranked = true;
  getDims(res.dims);
}

int64_t ShapeAdaptor::getDimSize(int index) const {
  assert(hasRank());
  if (auto t = val.dyn_cast<Type>())
    return t.cast<ShapedType>().getDimSize(index);
  if (auto attr = val.dyn_cast<Attribute>())
    return attr.cast<DenseIntElementsAttr>()
        .getValues<APInt>()[index]
        .getSExtValue();
  auto *stc = val.get<ShapedTypeComponents *>();
  return stc->getDims()[index];
}

int64_t ShapeAdaptor::getRank() const {
  assert(hasRank());
  if (auto t = val.dyn_cast<Type>())
    return t.cast<ShapedType>().getRank();
  if (auto attr = val.dyn_cast<Attribute>())
    return attr.cast<DenseIntElementsAttr>().size();
  return val.get<ShapedTypeComponents *>()->getDims().size();
}

bool ShapeAdaptor::hasStaticShape() const {
  if (!hasRank())
    return false;

  if (auto t = val.dyn_cast<Type>())
    return t.cast<ShapedType>().hasStaticShape();
  if (auto attr = val.dyn_cast<Attribute>()) {
    auto dattr = attr.cast<DenseIntElementsAttr>();
    for (auto index : dattr.getValues<APInt>())
      if (ShapedType::isDynamic(index.getSExtValue()))
        return false;
    return true;
  }
  auto *stc = val.get<ShapedTypeComponents *>();
  return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
}

int64_t ShapeAdaptor::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");

  if (auto t = val.dyn_cast<Type>())
    return t.cast<ShapedType>().getNumElements();

  if (auto attr = val.dyn_cast<Attribute>()) {
    auto dattr = attr.cast<DenseIntElementsAttr>();
    int64_t num = 1;
    for (auto index : dattr.getValues<APInt>()) {
      num *= index.getZExtValue();
      assert(num >= 0 && "integer overflow in element count computation");
    }
    return num;
  }

  auto *stc = val.get<ShapedTypeComponents *>();
  int64_t num = 1;
  for (int64_t dim : stc->getDims()) {
    num *= dim;
    assert(num >= 0 && "integer overflow in element count computation");
  }
  return num;
}
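// Illustrative sketch (comment only, not compiled as part of this file) of
// how the queries above behave for a ShapeAdaptor built from a type. The
// dimensions, the f32 element type, and the in-scope `context` pointer are
// arbitrary example values, not anything this file defines:
//
//   ShapeAdaptor shape(RankedTensorType::get(
//       {2, ShapedType::kDynamicSize}, FloatType::getF32(context)));
//   shape.hasRank();        // true
//   shape.getRank();        // 2
//   shape.getDimSize(0);    // 2
//   shape.hasStaticShape(); // false: the second dimension is dynamic
//   SmallVector<int64_t> dims;
//   shape.getDims(dims);    // dims == {2, ShapedType::kDynamicSize}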
void ShapeAdaptor::dump() const {
  if (!hasRank()) {
    llvm::errs() << "<<unranked>>\n";
    return;
  }

  SmallVector<int64_t> dims;
  getDims(dims);
  auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
    if (ShapedType::isDynamic(dim))
      return "?";
    return llvm::formatv("{0}", dim).str();
  });
  llvm::errs() << "rank = " << getRank() << " dims = [";
  llvm::interleave(mapped, llvm::errs(), "x");
  llvm::errs() << "]\n";
}

ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
  Value val = operator[](index);
  if (valueToShape)
    if (ShapeAdaptor ret = valueToShape(val))
      return ret;

  // Fall back to a constant 1-D integer attribute that defines the shape.
  DenseIntElementsAttr attr;
  if (!matchPattern(val, m_Constant(&attr)))
    return nullptr;
  if (attr.getType().getRank() != 1)
    return nullptr;
  return attr;
}

ShapeAdaptor ValueShapeRange::getShape(Value val) const {
  if (operandShape)
    if (ShapeAdaptor ret = operandShape(val))
      return ret;
  return val.getType();
}

ShapeAdaptor ValueShapeRange::getShape(int index) const {
  if (index < 0 || static_cast<size_t>(index) >= size())
    return nullptr;
  return getShape(operator[](index));
}

LogicalResult mlir::detail::inferReturnTensorTypes(
    function_ref<
        LogicalResult(MLIRContext *, Optional<Location> location,
                      ValueShapeRange operands, DictionaryAttr attributes,
                      RegionRange regions,
                      SmallVectorImpl<ShapedTypeComponents> &retComponents)>
        componentTypeFn,
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  SmallVector<ShapedTypeComponents, 2> retComponents;
  if (failed(componentTypeFn(context, location, operands, attributes, regions,
                             retComponents)))
    return failure();
  for (const auto &shapeAndType : retComponents) {
    assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
    assert(shapeAndType.getElementType() &&
           "element type required to construct tensor");
    if (shapeAndType.hasRank())
      inferredReturnTypes.push_back(RankedTensorType::get(
          shapeAndType.getDims(), shapeAndType.getElementType()));
    else
      inferredReturnTypes.push_back(
          UnrankedTensorType::get(shapeAndType.getElementType()));
  }
  return success();
}

LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
  SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
  auto retTypeFn = cast<InferTypeOpInterface>(op);
  return retTypeFn.refineReturnTypes(op->getContext(), op->getLoc(),
                                     op->getOperands(), op->getAttrDictionary(),
                                     op->getRegions(), inferredReturnTypes);
}
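// Illustrative sketch (comment only, not compiled as part of this file): a
// typical client of inferReturnTensorTypes is an op implementing
// InferShapedTypeOpInterface. `MyOp` below is hypothetical; it propagates its
// first operand's shape and element type into the result components:
//
//   LogicalResult MyOp::inferReturnTypeComponents(
//       MLIRContext *context, Optional<Location> location,
//       ValueShapeRange operands, DictionaryAttr attributes,
//       RegionRange regions,
//       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
//     ShapeAdaptor operandShape = operands.getShape(0);
//     if (!operandShape.hasRank()) {
//       // Unranked operand: only the element type is known.
//       inferredReturnShapes.emplace_back(operandShape.getElementType());
//       return success();
//     }
//     SmallVector<int64_t> dims;
//     operandShape.getDims(dims);
//     inferredReturnShapes.emplace_back(dims, operandShape.getElementType());
//     return success();
//   }
//
// inferReturnTensorTypes then packages each ShapedTypeComponents into a
// RankedTensorType or UnrankedTensorType as shown above.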