//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
//
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

namespace mlir {
#include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
} // namespace mlir

/// Infers tensor result types from the shaped-type components computed by
/// `componentTypeFn`.
///
/// `componentTypeFn` is invoked with the given context, location, operands,
/// attributes and regions, and fills a list of `ShapedTypeComponents`. For
/// every component, one type is appended to `inferredReturnTypes`:
///   * a `RankedTensorType` built from the component's dimensions and element
///     type when the component has a rank;
///   * an `UnrankedTensorType` of the component's element type otherwise.
///
/// Returns failure, without appending anything, when `componentTypeFn` fails;
/// returns success otherwise. Components carrying an attribute are not
/// supported and trigger an assertion.
LogicalResult mlir::detail::inferReturnTensorTypes(
    function_ref<LogicalResult(
        MLIRContext *, Optional<Location> location, ValueRange operands,
        DictionaryAttr attributes, RegionRange regions,
        SmallVectorImpl<ShapedTypeComponents> &retComponents)>
        componentTypeFn,
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  SmallVector<ShapedTypeComponents, 2> retComponents;
  if (failed(componentTypeFn(context, location, operands, attributes, regions,
                             retComponents)))
    return failure();
  // The components are only read here, so bind by const reference rather than
  // copying each element on every iteration.
  for (const auto &shapeAndType : retComponents) {
    assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
    if (shapeAndType.hasRank())
      inferredReturnTypes.push_back(RankedTensorType::get(
          shapeAndType.getDims(), shapeAndType.getElementType()));
    else
      inferredReturnTypes.push_back(
          UnrankedTensorType::get(shapeAndType.getElementType()));
  }
  return success();
}

/// Verifies that the result types of `op` are compatible with the types
/// inferred through its `InferTypeOpInterface` implementation.
///
/// Returns failure if inference itself fails, or, after emitting an op error
/// listing both type lists, if the interface reports the inferred types as
/// incompatible with the op's actual result types. Returns success otherwise.
LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
  auto inferInterface = cast<InferTypeOpInterface>(op);

  // Ask the op to infer its result types from its own operands, attributes
  // and regions.
  SmallVector<Type, 4> inferredTypes;
  if (failed(inferInterface.inferReturnTypes(
          op->getContext(), op->getLoc(), op->getOperands(),
          op->getAttrDictionary(), op->getRegions(), inferredTypes)))
    return failure();

  // Compare the inferred types against the types the op actually carries.
  auto actualTypes = op->getResultTypes();
  if (inferInterface.isCompatibleReturnTypes(inferredTypes, actualTypes))
    return success();

  return op->emitOpError("inferred type(s) ")
         << inferredTypes
         << " are incompatible with return type(s) of operation "
         << actualTypes;
}
