//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file contains the definitions of the infer op interfaces defined in // `InferTypeOpInterface.td`. // //===----------------------------------------------------------------------===// #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/IR/StandardTypes.h" using namespace mlir; namespace mlir { #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc" } // namespace mlir LogicalResult mlir::detail::inferReturnTensorTypes( function_ref location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &retComponents)> componentTypeFn, MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { SmallVector retComponents; if (failed(componentTypeFn(context, location, operands, attributes, regions, retComponents))) return failure(); for (auto shapeAndType : retComponents) { assert(shapeAndType.getAttribute() == nullptr && "attribute not supported"); if (shapeAndType.hasRank()) inferredReturnTypes.push_back(RankedTensorType::get( shapeAndType.getDims(), shapeAndType.getElementType())); else inferredReturnTypes.push_back( UnrankedTensorType::get(shapeAndType.getElementType())); } return success(); } LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { SmallVector inferredReturnTypes; auto retTypeFn = cast(op); if (failed(retTypeFn.inferReturnTypes( op->getContext(), op->getLoc(), op->getOperands(), op->getAttrDictionary(), op->getRegions(), inferredReturnTypes))) return failure(); if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes, op->getResultTypes())) return op->emitOpError( "inferred type incompatible with return type of operation"); return success(); }