//===- InferTypeOpInterface.td - Infer Type interfaces -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a set of interfaces that can be used to define information
// related to type inference.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INFERTYPEOPINTERFACE
#define MLIR_INFERTYPEOPINTERFACE

include "mlir/IR/OpBase.td"

// OpInterface to compute the return type of an operation. The arguments match
// those in Operation::create with the exception that the location is optional
// (if no location is provided, then the method will not emit an error on
// mismatch).
def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
  let description = [{
    Interface to infer the return types for an operation that could be used
    during op construction, verification or type inference.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    // No default implementation: every op implementing this interface must
    // provide a static `inferReturnTypes`.
    StaticInterfaceMethod<
      /*desc=*/[{Infer the return types that an op would generate.

      The method takes an optional location which, if set, will be used to
      report errors on. The operands and attributes correspond to those with
      which an Operation would be created (e.g., as used in Operation::create)
      and the regions of the op. Be aware that this method is supposed to be
      called with valid arguments, e.g., operands are verified, or it may result
      in an undefined behavior.
      }],
      /*retTy=*/"::mlir::LogicalResult",
      /*methodName=*/"inferReturnTypes",
      /*args=*/(ins "::mlir::MLIRContext *":$context,
                    "::llvm::Optional<::mlir::Location>":$location,
                    "::mlir::ValueRange":$operands,
                    "::mlir::DictionaryAttr":$attributes,
                    "::mlir::RegionRange":$regions,
                    "::llvm::SmallVectorImpl<::mlir::Type>&":$inferredReturnTypes)
    >,
    // The default implementation runs `inferReturnTypes` and compares the
    // result against `returnTypes` via `isCompatibleReturnTypes`.
    StaticInterfaceMethod<
      /*desc=*/[{Refine the return types that an op would generate.

      This method computes the return types as `inferReturnTypes` does but
      additionally takes the existing result types as input. The existing
      result types can be checked during inference to provide more op-specific
      error messages, and can also be used to merge additional information
      (e.g., attributes) into the inferred types. It is called during
      verification of ops implementing this interface; the default
      implementation reports a mismatch, printing both the current and the
      inferred types.

      The operands and attributes correspond to those with which an Operation
      would be created (e.g., as used in Operation::create) and the regions of
      the op. The method takes an optional location which, if set, will be used
      to report errors on.

      The return types may be elided, or specific elements may be null, for
      elements that should just be returned but not verified.

      Be aware that this method is supposed to be called with valid arguments,
      e.g., operands are verified, or it may result in an undefined behavior.
      }],
      /*retTy=*/"::mlir::LogicalResult",
      /*methodName=*/"refineReturnTypes",
      /*args=*/(ins "::mlir::MLIRContext *":$context,
                    "::llvm::Optional<::mlir::Location>":$location,
                    "::mlir::ValueRange":$operands,
                    "::mlir::DictionaryAttr":$attributes,
                    "::mlir::RegionRange":$regions,
                    "::llvm::SmallVectorImpl<::mlir::Type>&":$returnTypes),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{
          // Compute the types the op would infer from its inputs.
          ::llvm::SmallVector<::mlir::Type, 4> inferredReturnTypes;
          if (::mlir::failed(ConcreteOp::inferReturnTypes(
                  context, location, operands, attributes, regions,
                  inferredReturnTypes)))
            return ::mlir::failure();
          // Report (only when a location was supplied) any mismatch between
          // the inferred types and the existing ones.
          if (!ConcreteOp::isCompatibleReturnTypes(inferredReturnTypes,
                                                   returnTypes)) {
            return ::mlir::emitOptionalError(
                location, "'", ConcreteOp::getOperationName(),
                "' op inferred type(s) ", inferredReturnTypes,
                " are incompatible with return type(s) of operation ",
                returnTypes);
          }
          return ::mlir::success();
      }]
    >,
    // The interface method forwards to the op's static hook; ops that do not
    // override it get exact equality of the two type lists.
    StaticInterfaceMethod<
      /*desc=*/"Returns whether two arrays of types are compatible result types"
               " for an op.",
      /*retTy=*/"bool",
      /*methodName=*/"isCompatibleReturnTypes",
      /*args=*/(ins "::mlir::TypeRange":$lhs, "::mlir::TypeRange":$rhs),
      /*methodBody=*/[{
        return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
      }],
      /*defaultImplementation=*/[{
        /// Returns whether two arrays are equal as strongest check for
        /// compatibility by default.
        return lhs == rhs;
      }]
    >,
  ];

  // Inferring result types may need to access the region operations.
  let verifyWithRegions = 1;
  let verify = [{
    return ::mlir::detail::verifyInferredResultTypes($_op);
  }];
}

// OpInterface for ops whose results are shaped types: infers the components
// (element type, shape and attribute) of the results and can materialize the
// result-shape computation as IR. Both methods default to failure.
def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
  let description = [{
    Interface to infer the components of a ShapedType returned by an operation
    that could be used during op construction, verification or shape inference.

    The components consist of the element type, shape and raw attribute.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    StaticInterfaceMethod<
      /*desc=*/[{Infer the components of the return type of a shape container.

      The method takes an optional location which, if set, will be used to
      report errors on. The operands and attributes correspond to those with
      which an Operation would be created (e.g., as used in Operation::create)
      and the regions of the op.

      Unknown (e.g., unranked) shape and nullptrs for element type and attribute
      may be returned by this function while returning success. That is, partial
      population of components is not an error condition.
      }],
      /*retTy=*/"::mlir::LogicalResult",
      /*methodName=*/"inferReturnTypeComponents",
      /*args=*/(ins "::mlir::MLIRContext *":$context,
                    "::llvm::Optional<::mlir::Location>":$location,
                    "::mlir::ValueShapeRange":$operands,
                    "::mlir::DictionaryAttr":$attributes,
                    "::mlir::RegionRange":$regions,
                    "::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents>&":
                      $inferredReturnShapes),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
    >,
    InterfaceMethod<
      /*desc=*/[{Reify the shape computation for the operation.

      Insert operations using the given OpBuilder that compute the result
      shape. This interface is designed to work during dialect conversion
      (e.g., converting from the tensor world to the buffer world), where
      `getOperand` may be invalid. For example, some ops (e.g.
      dynamic_reshape(input, target_shape)) may depend on their operands
      to calculate the result shape. When the `matchAndRewrite` method
      of a conversion pattern is called, the operands of the op to convert
      may have been converted into other types, which makes it invalid to
      call the `getOperand` method of such op directly inside the
      conversion pattern. To solve this problem, this interface follows
      the design of the conversion pattern, that is, it accepts the operands
      as an explicit argument to avoid calling `getOperand` directly inside
      the interface implementation.
      }],
      /*retTy=*/"::mlir::LogicalResult",
      /*methodName=*/"reifyReturnTypeShapes",
      /*args=*/(ins "::mlir::OpBuilder&":$builder,
                    "::mlir::ValueRange":$operands,
                    "::llvm::SmallVectorImpl<::mlir::Value> &":$reifiedReturnShapes),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
    >
  ];
}

// Convenience class grouping together type and shaped type op interfaces for
// ops that have tensor return types.
class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
  [
    // Op implements infer type op interface.
    InferTypeOpInterface,
    // The op will have methods implementing the ShapedType type inference
    // interface.
    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
    // The op produces tensors and will use the ShapedType type infer interface
    // along with knowledge that it is producing Tensors to infer the type.
    NativeOpTrait<"InferTensorType">
  ]>;

// Tensor-type inference where the op declares only
// `inferReturnTypeComponents`.
def InferTensorType : InferTensorTypeBase<["inferReturnTypeComponents"]>;
// Tensor-type inference where the op also declares `reifyReturnTypeShapes`.
def InferTensorTypeWithReify: InferTensorTypeBase<[
    "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;

// OpInterface to materialize, as IR, the shapes of results that have ranked
// shaped types (e.g. RankedTensorType or MemRefType). The single method has
// no default implementation.
def ReifyRankedShapedTypeOpInterface :
    OpInterface<"ReifyRankedShapedTypeOpInterface"> {
  let description = [{
    Interface to compute the shape of the result of an operation when
    the result is a ranked shaped type, i.e. `RankedTensorType` or
    `MemRefType`.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    InterfaceMethod<
      /*desc=*/[{
        Reify the shape of the result of an operation (typically in
        terms of shape of its operands).

        Insert operations using the given `OpBuilder` that compute
        the result shape. The `reifiedReturnShapes` is expected to be
        populated with as many vectors as the number of results of the
        op. Each of these vectors is expected to be of size equal to
        rank of the corresponding result. If the shape of a particular
        result cannot be computed, it must be empty.
      }],
      /*retTy=*/"::mlir::LogicalResult",
      /*methodName=*/"reifyResultShapes",
      /*args=*/(ins "::mlir::OpBuilder &":$builder,
        "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
    >
  ];
}

// Op has the same operand and result type.
// TODO: Change from hard coded to utilizing type inference trait.
def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;

#endif // MLIR_INFERTYPEOPINTERFACE