1//===- InferTypeOpInterface.td - Infer Type interfaces -----*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a set of interfaces that can be used to define information
10// related to type inference.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_INFERTYPEOPINTERFACE
15#define MLIR_INFERTYPEOPINTERFACE
16
17include "mlir/IR/OpBase.td"
18
// OpInterface to compute the return type of an operation. The arguments match
// those in Operation::create with the exception that the location is optional
// (if no location is provided, then the method will not emit an error on
// mismatch).
//
// Note: the C++ snippets below (default implementations and the verifier) use
// fully qualified names, matching the rest of this file, so that they do not
// depend on the namespace the generated trait code is emitted into.
def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
  let description = [{
    Interface to infer the return types for an operation that could be used
    during op construction, verification or type inference.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    StaticInterfaceMethod<
      /*desc=*/[{Infer the return types that an op would generate.

      The method takes an optional location which, if set, will be used to
      report errors on. The operands and attributes correspond to those with
      which an Operation would be created (e.g., as used in Operation::create)
      and the regions of the op. Be aware that this method is supposed to be
      called with valid arguments, e.g., operands are verified; otherwise it
      may result in undefined behavior.
      }],
      /*retTy=*/"::mlir::LogicalResult",
      /*methodName=*/"inferReturnTypes",
      /*args=*/(ins "::mlir::MLIRContext *":$context,
                    "::llvm::Optional<::mlir::Location>":$location,
                    "::mlir::ValueRange":$operands,
                    "::mlir::DictionaryAttr":$attributes,
                    "::mlir::RegionRange":$regions,
                    "::llvm::SmallVectorImpl<::mlir::Type>&":$inferredReturnTypes)
    >,
    StaticInterfaceMethod<
      /*desc=*/[{Refine the return types that an op would generate.

      This method computes the return types as `inferReturnTypes` does but
      additionally takes the existing result types as input. The existing
      result types can be checked as part of inference to provide more
      op-specific error messages, as well as to merge additional information
      (such as attributes) during inference. It is called during verification
      for ops implementing this interface; the default implementation reports
      a mismatch, printing both the current and the inferred types.

      The operands and attributes correspond to those with which an Operation
      would be created (e.g., as used in Operation::create) and the regions of
      the op. The method takes an optional location which, if set, will be used
      to report errors on.

      The return types may be elided, or specific elements may be null, for
      elements that should just be returned but not verified.

      Be aware that this method is supposed to be called with valid arguments,
      e.g., operands are verified; otherwise it may result in undefined
      behavior.
      }],
      /*retTy=*/"::mlir::LogicalResult",
      /*methodName=*/"refineReturnTypes",
      /*args=*/(ins "::mlir::MLIRContext *":$context,
                    "::llvm::Optional<::mlir::Location>":$location,
                    "::mlir::ValueRange":$operands,
                    "::mlir::DictionaryAttr":$attributes,
                    "::mlir::RegionRange":$regions,
                    "::llvm::SmallVectorImpl<::mlir::Type>&":$returnTypes),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{
          // Infer the types from scratch, then compare them against the
          // existing result types using the op's compatibility hook.
          ::llvm::SmallVector<::mlir::Type, 4> inferredReturnTypes;
          if (::mlir::failed(ConcreteOp::inferReturnTypes(
                  context, location, operands, attributes, regions,
                  inferredReturnTypes)))
            return ::mlir::failure();
          if (!ConcreteOp::isCompatibleReturnTypes(inferredReturnTypes,
                                                   returnTypes)) {
            return ::mlir::emitOptionalError(
                location, "'", ConcreteOp::getOperationName(),
                "' op inferred type(s) ", inferredReturnTypes,
                " are incompatible with return type(s) of operation ",
                returnTypes);
          }
          return ::mlir::success();
      }]
    >,
    StaticInterfaceMethod<
      /*desc=*/"Returns whether two arrays of types are compatible result types"
               " for an op.",
      /*retTy=*/"bool",
      /*methodName=*/"isCompatibleReturnTypes",
      /*args=*/(ins "::mlir::TypeRange":$lhs, "::mlir::TypeRange":$rhs),
      /*methodBody=*/[{
        return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
      }],
      /*defaultImplementation=*/[{
        // By default, require exact equality: the strongest possible
        // compatibility check.
        return lhs == rhs;
      }]
    >,
  ];

  // Inferring result types may need to access the region operations.
  let verifyWithRegions = 1;
  let verify = [{
    return ::mlir::detail::verifyInferredResultTypes($_op);
  }];
}
121
def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
  let description = [{
    Interface to infer the components of a ShapedType returned by an operation
    that could be used during op construction, verification or shape inference.

    The components consist of element type, shape and raw attribute.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    StaticInterfaceMethod<
      /*desc=*/[{Infer the components of the return type of a shaped container.

      The method takes an optional location which, if set, will be used to
      report errors on. The operands and attributes correspond to those with
      which an Operation would be created (e.g., as used in Operation::create)
      and the regions of the op.

      Unknown (e.g., unranked) shapes and nullptrs for element type and
      attribute may be returned by this function while returning success,
      i.e., partial population of components is not an error condition.
      }],
      /*retTy=*/"::mlir::LogicalResult",
      /*methodName=*/"inferReturnTypeComponents",
      /*args=*/(ins "::mlir::MLIRContext*":$context,
                    "::llvm::Optional<::mlir::Location>":$location,
                    "::mlir::ValueShapeRange":$operands,
                    "::mlir::DictionaryAttr":$attributes,
                    "::mlir::RegionRange":$regions,
                    "::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents>&":
                      $inferredReturnShapes),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
    >,
    InterfaceMethod<
      /*desc=*/[{Reify the shape computation for the operation.

      Insert operations using the given OpBuilder that compute the
      result shape. This interface is supposed to be workable during dialect
      conversion (e.g., converting from the tensor world to the buffer world),
      where `getOperand` may be invalid. For example, some ops (e.g.,
      dynamic_reshape(input, target_shape)) may depend on their operands
      to calculate the result shape. When the `matchAndRewrite` method
      of a conversion pattern is called, the operands of the op to convert
      may have been converted into other types, which makes it invalid to
      call the `getOperand` method of such an op directly inside the
      conversion pattern. To solve this problem, this interface follows
      the design of the conversion pattern, that is, it accepts passed-in
      operands to avoid calling `getOperand` directly inside the interface
      implementation.
      }],
      /*retTy=*/"::mlir::LogicalResult",
      /*methodName=*/"reifyReturnTypeShapes",
      /*args=*/(ins "::mlir::OpBuilder&":$builder,
          "::mlir::ValueRange":$operands,
          "::llvm::SmallVectorImpl<::mlir::Value> &":$reifiedReturnShapes),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
    >
  ];
}
183
// Convenience trait list bundling the type inference and shaped-type inference
// op interfaces for ops that have tensor return types.
//
// `overriddenMethods` is forwarded unchanged to DeclareOpInterfaceMethods for
// InferShapedTypeOpInterface, i.e. it lists the interface methods whose
// declarations are generated on the op.
class InferTensorTypeBase<list<string> overriddenMethods = []> : TraitList<[
  // Generic return type inference interface.
  InferTypeOpInterface,

  // Methods of the ShapedType inference interface, declared on the op for
  // every name in `overriddenMethods`.
  DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overriddenMethods>,

  // Native trait: the op produces tensors, and combines the ShapedType
  // inference interface with that knowledge to infer its result type.
  NativeOpTrait<"InferTensorType">
]>;
197
// Tensor type inference where the op itself declares
// `inferReturnTypeComponents` of InferShapedTypeOpInterface.
def InferTensorType : InferTensorTypeBase<["inferReturnTypeComponents"]>;

// Like InferTensorType, but the op additionally declares
// `reifyReturnTypeShapes`.
def InferTensorTypeWithReify : InferTensorTypeBase<[
  "inferReturnTypeComponents",
  "reifyReturnTypeShapes"
]>;
201
def ReifyRankedShapedTypeOpInterface :
    OpInterface<"ReifyRankedShapedTypeOpInterface"> {
  let description = [{
    Interface to compute the shape of the result of an operation when
    the result is a ranked shaped type, i.e. `RankedTensorType` or
    `MemRefType`.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    InterfaceMethod<
      /*desc=*/[{
        Reify the shape of the result of an operation (typically in
        terms of the shape of its operands).

        Insert operations using the given `OpBuilder` that compute
        the result shape. `reifiedReturnShapes` is expected to be
        populated with as many vectors as the number of results of the
        op. Each of these vectors is expected to be of size equal to
        the rank of the corresponding result. If the shape of a
        particular result cannot be computed, its vector must be empty.
      }],
      /*retTy=*/"::mlir::LogicalResult",
      /*methodName=*/"reifyResultShapes",
      /*args=*/(ins "::mlir::OpBuilder &":$builder,
        "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
    >
  ];
}
231
// Op has the same operand and result type.
//
// Implemented by the C++ native op trait named `SameOperandsAndResultType`;
// this record only exposes that trait to ODS.
// TODO: Change from hard coded to utilizing type inference trait.
def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
235
236#endif // MLIR_INFERTYPEOPINTERFACE
237