1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the definitions of the infer op interfaces defined in
10 // `InferTypeOpInterface.td`.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Interfaces/InferTypeOpInterface.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 
17 using namespace mlir;
18 
19 namespace mlir {
20 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
21 } // namespace mlir
22 
23 LogicalResult mlir::detail::inferReturnTensorTypes(
24     function_ref<LogicalResult(
25         MLIRContext *, Optional<Location> location, ValueRange operands,
26         DictionaryAttr attributes, RegionRange regions,
27         SmallVectorImpl<ShapedTypeComponents> &retComponents)>
28         componentTypeFn,
29     MLIRContext *context, Optional<Location> location, ValueRange operands,
30     DictionaryAttr attributes, RegionRange regions,
31     SmallVectorImpl<Type> &inferredReturnTypes) {
32   SmallVector<ShapedTypeComponents, 2> retComponents;
33   if (failed(componentTypeFn(context, location, operands, attributes, regions,
34                              retComponents)))
35     return failure();
36   for (auto shapeAndType : retComponents) {
37     assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
38     if (shapeAndType.hasRank())
39       inferredReturnTypes.push_back(RankedTensorType::get(
40           shapeAndType.getDims(), shapeAndType.getElementType()));
41     else
42       inferredReturnTypes.push_back(
43           UnrankedTensorType::get(shapeAndType.getElementType()));
44   }
45   return success();
46 }
47 
48 LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
49   SmallVector<Type, 4> inferredReturnTypes;
50   auto retTypeFn = cast<InferTypeOpInterface>(op);
51   if (failed(retTypeFn.inferReturnTypes(
52           op->getContext(), op->getLoc(), op->getOperands(),
53           op->getAttrDictionary(), op->getRegions(), inferredReturnTypes)))
54     return failure();
55   if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes,
56                                          op->getResultTypes()))
57     return op->emitOpError("inferred type(s) ")
58            << inferredReturnTypes
59            << " are incompatible with return type(s) of operation "
60            << op->getResultTypes();
61   return success();
62 }
63