1 //===- Operator.h - Operator class ------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Operator wrapper to simplify using TableGen Record defining a MLIR Op. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_TABLEGEN_OPERATOR_H_ 14 #define MLIR_TABLEGEN_OPERATOR_H_ 15 16 #include "mlir/Support/LLVM.h" 17 #include "mlir/TableGen/Argument.h" 18 #include "mlir/TableGen/Attribute.h" 19 #include "mlir/TableGen/Builder.h" 20 #include "mlir/TableGen/Dialect.h" 21 #include "mlir/TableGen/Region.h" 22 #include "mlir/TableGen/Successor.h" 23 #include "mlir/TableGen/Trait.h" 24 #include "mlir/TableGen/Type.h" 25 #include "llvm/ADT/PointerUnion.h" 26 #include "llvm/ADT/SmallVector.h" 27 #include "llvm/ADT/StringMap.h" 28 #include "llvm/ADT/StringRef.h" 29 #include "llvm/Support/SMLoc.h" 30 31 namespace llvm { 32 class DefInit; 33 class Record; 34 class StringInit; 35 } // namespace llvm 36 37 namespace mlir { 38 namespace tblgen { 39 40 // Wrapper class that contains a MLIR op's information (e.g., operands, 41 // attributes) defined in TableGen and provides helper methods for 42 // accessing them. 43 class Operator { 44 public: 45 explicit Operator(const llvm::Record &def); Operator(const llvm::Record * def)46 explicit Operator(const llvm::Record *def) : Operator(*def) {} 47 48 // Returns this op's dialect name. 49 StringRef getDialectName() const; 50 51 // Returns the operation name. The name will follow the "<dialect>.<op-name>" 52 // format if its dialect name is not empty. 53 std::string getOperationName() const; 54 55 // Returns this op's C++ class name. 56 StringRef getCppClassName() const; 57 58 // Returns this op's C++ class name prefixed with namespaces. 59 std::string getQualCppClassName() const; 60 61 // Returns this op's C++ namespace. 62 StringRef getCppNamespace() const; 63 64 // Returns the name of op's adaptor C++ class. 65 std::string getAdaptorName() const; 66 67 // Check invariants (like no duplicated or conflicted names) and abort the 68 // process if any invariant is broken. 69 void assertInvariants() const; 70 71 /// A class used to represent the decorators of an operator variable, i.e. 72 /// argument or result. 73 struct VariableDecorator { 74 public: VariableDecoratorVariableDecorator75 explicit VariableDecorator(const llvm::Record *def) : def(def) {} getDefVariableDecorator76 const llvm::Record &getDef() const { return *def; } 77 78 protected: 79 // The TableGen definition of this decorator. 80 const llvm::Record *def; 81 }; 82 83 // A utility iterator over a list of variable decorators. 84 struct VariableDecoratorIterator 85 : public llvm::mapped_iterator<llvm::Init *const *, 86 VariableDecorator (*)(llvm::Init *)> { 87 /// Initializes the iterator to the specified iterator. VariableDecoratorIteratorVariableDecoratorIterator88 VariableDecoratorIterator(llvm::Init *const *it) 89 : llvm::mapped_iterator<llvm::Init *const *, 90 VariableDecorator (*)(llvm::Init *)>(it, 91 &unwrap) {} 92 static VariableDecorator unwrap(llvm::Init *init); 93 }; 94 using var_decorator_iterator = VariableDecoratorIterator; 95 using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>; 96 97 using value_iterator = NamedTypeConstraint *; 98 using const_value_iterator = const NamedTypeConstraint *; 99 using value_range = llvm::iterator_range<value_iterator>; 100 using const_value_range = llvm::iterator_range<const_value_iterator>; 101 102 // Returns true if this op has variable length operands or results. 103 bool isVariadic() const; 104 105 // Returns true if default builders should not be generated. 106 bool skipDefaultBuilders() const; 107 108 // Op result iterators. 109 const_value_iterator result_begin() const; 110 const_value_iterator result_end() const; 111 const_value_range getResults() const; 112 113 // Returns the number of results this op produces. 114 int getNumResults() const; 115 116 // Returns the op result at the given `index`. getResult(int index)117 NamedTypeConstraint &getResult(int index) { return results[index]; } getResult(int index)118 const NamedTypeConstraint &getResult(int index) const { 119 return results[index]; 120 } 121 122 // Returns the `index`-th result's type constraint. 123 TypeConstraint getResultTypeConstraint(int index) const; 124 // Returns the `index`-th result's name. 125 StringRef getResultName(int index) const; 126 // Returns the `index`-th result's decorators. 127 var_decorator_range getResultDecorators(int index) const; 128 129 // Returns the number of variable length results in this operation. 130 unsigned getNumVariableLengthResults() const; 131 132 // Op attribute iterators. 133 using attribute_iterator = const NamedAttribute *; 134 attribute_iterator attribute_begin() const; 135 attribute_iterator attribute_end() const; 136 llvm::iterator_range<attribute_iterator> getAttributes() const; 137 getNumAttributes()138 int getNumAttributes() const { return attributes.size(); } getNumNativeAttributes()139 int getNumNativeAttributes() const { return numNativeAttributes; } 140 141 // Op attribute accessors. getAttribute(int index)142 NamedAttribute &getAttribute(int index) { return attributes[index]; } getAttribute(int index)143 const NamedAttribute &getAttribute(int index) const { 144 return attributes[index]; 145 } 146 147 // Op operand iterators. 148 const_value_iterator operand_begin() const; 149 const_value_iterator operand_end() const; 150 const_value_range getOperands() const; 151 getNumOperands()152 int getNumOperands() const { return operands.size(); } getOperand(int index)153 NamedTypeConstraint &getOperand(int index) { return operands[index]; } getOperand(int index)154 const NamedTypeConstraint &getOperand(int index) const { 155 return operands[index]; 156 } 157 158 // Returns the number of variadic operands in this operation. 159 unsigned getNumVariableLengthOperands() const; 160 161 // Returns the total number of arguments. getNumArgs()162 int getNumArgs() const { return arguments.size(); } 163 164 // Returns true of the operation has a single variadic arg. 165 bool hasSingleVariadicArg() const; 166 167 // Returns true if the operation has a single variadic result. hasSingleVariadicResult()168 bool hasSingleVariadicResult() const { 169 return getNumResults() == 1 && getResult(0).isVariadic(); 170 } 171 172 // Returns true of the operation has no variadic regions. hasNoVariadicRegions()173 bool hasNoVariadicRegions() const { return getNumVariadicRegions() == 0; } 174 175 using arg_iterator = const Argument *; 176 using arg_range = llvm::iterator_range<arg_iterator>; 177 178 // Op argument (attribute or operand) iterators. 179 arg_iterator arg_begin() const; 180 arg_iterator arg_end() const; 181 arg_range getArgs() const; 182 183 // Op argument (attribute or operand) accessors. 184 Argument getArg(int index) const; 185 StringRef getArgName(int index) const; 186 var_decorator_range getArgDecorators(int index) const; 187 188 // Returns the trait wrapper for the given MLIR C++ `trait`. 189 const Trait *getTrait(llvm::StringRef trait) const; 190 191 // Regions. 192 using const_region_iterator = const NamedRegion *; 193 const_region_iterator region_begin() const; 194 const_region_iterator region_end() const; 195 llvm::iterator_range<const_region_iterator> getRegions() const; 196 197 // Returns the number of regions. 198 unsigned getNumRegions() const; 199 // Returns the `index`-th region. 200 const NamedRegion &getRegion(unsigned index) const; 201 202 // Returns the number of variadic regions in this operation. 203 unsigned getNumVariadicRegions() const; 204 205 // Successors. 206 using const_successor_iterator = const NamedSuccessor *; 207 const_successor_iterator successor_begin() const; 208 const_successor_iterator successor_end() const; 209 llvm::iterator_range<const_successor_iterator> getSuccessors() const; 210 211 // Returns the number of successors. 212 unsigned getNumSuccessors() const; 213 // Returns the `index`-th successor. 214 const NamedSuccessor &getSuccessor(unsigned index) const; 215 216 // Returns the number of variadic successors in this operation. 217 unsigned getNumVariadicSuccessors() const; 218 219 // Trait. 220 using const_trait_iterator = const Trait *; 221 const_trait_iterator trait_begin() const; 222 const_trait_iterator trait_end() const; 223 llvm::iterator_range<const_trait_iterator> getTraits() const; 224 225 ArrayRef<SMLoc> getLoc() const; 226 227 // Query functions for the documentation of the operator. 228 bool hasDescription() const; 229 StringRef getDescription() const; 230 bool hasSummary() const; 231 StringRef getSummary() const; 232 233 // Query functions for the assembly format of the operator. 234 bool hasAssemblyFormat() const; 235 StringRef getAssemblyFormat() const; 236 237 // Returns this op's extra class declaration code. 238 StringRef getExtraClassDeclaration() const; 239 240 // Returns this op's extra class definition code. 241 StringRef getExtraClassDefinition() const; 242 243 // Returns the Tablegen definition this operator was constructed from. 244 // TODO: do not expose the TableGen record, this is a temporary solution to 245 // OpEmitter requiring a Record because Operator does not provide enough 246 // methods. 247 const llvm::Record &getDef() const; 248 249 // Returns the dialect of the op. getDialect()250 const Dialect &getDialect() const { return dialect; } 251 252 // Prints the contents in this operator to the given `os`. This is used for 253 // debugging purposes. 254 void print(llvm::raw_ostream &os) const; 255 256 // Return whether all the result types are known. allResultTypesKnown()257 bool allResultTypesKnown() const { return allResultsHaveKnownTypes; }; 258 259 // Pair representing either a index to an argument or a type constraint. Only 260 // one of these entries should have the non-default value. 261 struct ArgOrType { ArgOrTypeArgOrType262 explicit ArgOrType(int index) : index(index), constraint(None) {} ArgOrTypeArgOrType263 explicit ArgOrType(TypeConstraint constraint) 264 : index(None), constraint(constraint) {} isArgArgOrType265 bool isArg() const { 266 assert(constraint.has_value() ^ index.has_value()); 267 return index.has_value(); 268 } isTypeArgOrType269 bool isType() const { 270 assert(constraint.has_value() ^ index.has_value()); 271 return constraint.has_value(); 272 } 273 getArgArgOrType274 int getArg() const { return *index; } getTypeArgOrType275 TypeConstraint getType() const { return *constraint; } 276 277 private: 278 Optional<int> index; 279 Optional<TypeConstraint> constraint; 280 }; 281 282 // Return all arguments or type constraints with same type as result[index]. 283 // Requires: all result types are known. 284 ArrayRef<ArgOrType> getSameTypeAsResult(int index) const; 285 286 // Pair consisting kind of argument and index into operands or attributes. 287 struct OperandOrAttribute { 288 enum class Kind { Operand, Attribute }; OperandOrAttributeOperandOrAttribute289 OperandOrAttribute(Kind kind, int index) { 290 packed = (index << 1) | (kind == Kind::Attribute); 291 } operandOrAttributeIndexOperandOrAttribute292 int operandOrAttributeIndex() const { return (packed >> 1); } kindOperandOrAttribute293 Kind kind() { return (packed & 0x1) ? Kind::Attribute : Kind::Operand; } 294 295 private: 296 int packed; 297 }; 298 299 // Returns the OperandOrAttribute corresponding to the index. 300 OperandOrAttribute getArgToOperandOrAttribute(int index) const; 301 302 // Returns the builders of this operation. getBuilders()303 ArrayRef<Builder> getBuilders() const { return builders; } 304 305 // Returns the preferred getter name for the accessor. getGetterName(StringRef name)306 std::string getGetterName(StringRef name) const { 307 return getGetterNames(name).front(); 308 } 309 310 // Returns the getter names for the accessor. 311 SmallVector<std::string, 2> getGetterNames(StringRef name) const; 312 313 // Returns the setter names for the accessor. 314 SmallVector<std::string, 2> getSetterNames(StringRef name) const; 315 316 private: 317 // Populates the vectors containing operands, attributes, results and traits. 318 void populateOpStructure(); 319 320 // Populates type inference info (mostly equality) with input a mapping from 321 // names to indices for arguments and results. 322 void populateTypeInferenceInfo( 323 const llvm::StringMap<int> &argumentsAndResultsIndex); 324 325 // The dialect of this op. 326 Dialect dialect; 327 328 // The unqualified C++ class name of the op. 329 StringRef cppClassName; 330 331 // The C++ namespace for this op. 332 StringRef cppNamespace; 333 334 // The operands of the op. 335 SmallVector<NamedTypeConstraint, 4> operands; 336 337 // The attributes of the op. Contains native attributes (corresponding to the 338 // actual stored attributed of the operation) followed by derived attributes 339 // (corresponding to dynamic properties of the operation that are computed 340 // upon request). 341 SmallVector<NamedAttribute, 4> attributes; 342 343 // The arguments of the op (operands and native attributes). 344 SmallVector<Argument, 4> arguments; 345 346 // The results of the op. 347 SmallVector<NamedTypeConstraint, 4> results; 348 349 // The successors of this op. 350 SmallVector<NamedSuccessor, 0> successors; 351 352 // The traits of the op. 353 SmallVector<Trait, 4> traits; 354 355 // The regions of this op. 356 SmallVector<NamedRegion, 1> regions; 357 358 // The argument with the same type as the result. 359 SmallVector<SmallVector<ArgOrType, 2>, 4> resultTypeMapping; 360 361 // Map from argument to attribute or operand number. 362 SmallVector<OperandOrAttribute, 4> attrOrOperandMapping; 363 364 // The builders of this operator. 365 SmallVector<Builder> builders; 366 367 // The number of native attributes stored in the leading positions of 368 // `attributes`. 369 int numNativeAttributes; 370 371 // The TableGen definition of this op. 372 const llvm::Record &def; 373 374 // Whether the type of all results are known. 375 bool allResultsHaveKnownTypes; 376 }; 377 378 } // namespace tblgen 379 } // namespace mlir 380 381 #endif // MLIR_TABLEGEN_OPERATOR_H_ 382