//===- Operator.h - Operator class ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Operator wrapper to simplify using a TableGen Record that defines an MLIR Op.
//
//===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TABLEGEN_OPERATOR_H_
14 #define MLIR_TABLEGEN_OPERATOR_H_
15 
16 #include "mlir/Support/LLVM.h"
17 #include "mlir/TableGen/Argument.h"
18 #include "mlir/TableGen/Attribute.h"
19 #include "mlir/TableGen/Builder.h"
20 #include "mlir/TableGen/Dialect.h"
21 #include "mlir/TableGen/Region.h"
22 #include "mlir/TableGen/Successor.h"
23 #include "mlir/TableGen/Trait.h"
24 #include "mlir/TableGen/Type.h"
25 #include "llvm/ADT/PointerUnion.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringMap.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/SMLoc.h"
30 
// Forward declarations of LLVM TableGen classes, so that this header only
// needs their names and not their full definitions. `Init` and `Record` are
// used by pointer/reference in the declarations of this header; `Init` is
// declared here explicitly so the header does not depend on another include
// happening to declare it.
namespace llvm {
class DefInit;
class Init;
class Record;
class StringInit;
} // namespace llvm
36 
37 namespace mlir {
38 namespace tblgen {
39 
40 // Wrapper class that contains a MLIR op's information (e.g., operands,
41 // attributes) defined in TableGen and provides helper methods for
42 // accessing them.
43 class Operator {
44 public:
45   explicit Operator(const llvm::Record &def);
Operator(const llvm::Record * def)46   explicit Operator(const llvm::Record *def) : Operator(*def) {}
47 
48   // Returns this op's dialect name.
49   StringRef getDialectName() const;
50 
51   // Returns the operation name. The name will follow the "<dialect>.<op-name>"
52   // format if its dialect name is not empty.
53   std::string getOperationName() const;
54 
55   // Returns this op's C++ class name.
56   StringRef getCppClassName() const;
57 
58   // Returns this op's C++ class name prefixed with namespaces.
59   std::string getQualCppClassName() const;
60 
61   // Returns this op's C++ namespace.
62   StringRef getCppNamespace() const;
63 
64   // Returns the name of op's adaptor C++ class.
65   std::string getAdaptorName() const;
66 
67   // Check invariants (like no duplicated or conflicted names) and abort the
68   // process if any invariant is broken.
69   void assertInvariants() const;
70 
71   /// A class used to represent the decorators of an operator variable, i.e.
72   /// argument or result.
73   struct VariableDecorator {
74   public:
VariableDecoratorVariableDecorator75     explicit VariableDecorator(const llvm::Record *def) : def(def) {}
getDefVariableDecorator76     const llvm::Record &getDef() const { return *def; }
77 
78   protected:
79     // The TableGen definition of this decorator.
80     const llvm::Record *def;
81   };
82 
83   // A utility iterator over a list of variable decorators.
84   struct VariableDecoratorIterator
85       : public llvm::mapped_iterator<llvm::Init *const *,
86                                      VariableDecorator (*)(llvm::Init *)> {
87     /// Initializes the iterator to the specified iterator.
VariableDecoratorIteratorVariableDecoratorIterator88     VariableDecoratorIterator(llvm::Init *const *it)
89         : llvm::mapped_iterator<llvm::Init *const *,
90                                 VariableDecorator (*)(llvm::Init *)>(it,
91                                                                      &unwrap) {}
92     static VariableDecorator unwrap(llvm::Init *init);
93   };
94   using var_decorator_iterator = VariableDecoratorIterator;
95   using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;
96 
97   using value_iterator = NamedTypeConstraint *;
98   using const_value_iterator = const NamedTypeConstraint *;
99   using value_range = llvm::iterator_range<value_iterator>;
100   using const_value_range = llvm::iterator_range<const_value_iterator>;
101 
102   // Returns true if this op has variable length operands or results.
103   bool isVariadic() const;
104 
105   // Returns true if default builders should not be generated.
106   bool skipDefaultBuilders() const;
107 
108   // Op result iterators.
109   const_value_iterator result_begin() const;
110   const_value_iterator result_end() const;
111   const_value_range getResults() const;
112 
113   // Returns the number of results this op produces.
114   int getNumResults() const;
115 
116   // Returns the op result at the given `index`.
getResult(int index)117   NamedTypeConstraint &getResult(int index) { return results[index]; }
getResult(int index)118   const NamedTypeConstraint &getResult(int index) const {
119     return results[index];
120   }
121 
122   // Returns the `index`-th result's type constraint.
123   TypeConstraint getResultTypeConstraint(int index) const;
124   // Returns the `index`-th result's name.
125   StringRef getResultName(int index) const;
126   // Returns the `index`-th result's decorators.
127   var_decorator_range getResultDecorators(int index) const;
128 
129   // Returns the number of variable length results in this operation.
130   unsigned getNumVariableLengthResults() const;
131 
132   // Op attribute iterators.
133   using attribute_iterator = const NamedAttribute *;
134   attribute_iterator attribute_begin() const;
135   attribute_iterator attribute_end() const;
136   llvm::iterator_range<attribute_iterator> getAttributes() const;
137 
getNumAttributes()138   int getNumAttributes() const { return attributes.size(); }
getNumNativeAttributes()139   int getNumNativeAttributes() const { return numNativeAttributes; }
140 
141   // Op attribute accessors.
getAttribute(int index)142   NamedAttribute &getAttribute(int index) { return attributes[index]; }
getAttribute(int index)143   const NamedAttribute &getAttribute(int index) const {
144     return attributes[index];
145   }
146 
147   // Op operand iterators.
148   const_value_iterator operand_begin() const;
149   const_value_iterator operand_end() const;
150   const_value_range getOperands() const;
151 
getNumOperands()152   int getNumOperands() const { return operands.size(); }
getOperand(int index)153   NamedTypeConstraint &getOperand(int index) { return operands[index]; }
getOperand(int index)154   const NamedTypeConstraint &getOperand(int index) const {
155     return operands[index];
156   }
157 
158   // Returns the number of variadic operands in this operation.
159   unsigned getNumVariableLengthOperands() const;
160 
161   // Returns the total number of arguments.
getNumArgs()162   int getNumArgs() const { return arguments.size(); }
163 
164   // Returns true of the operation has a single variadic arg.
165   bool hasSingleVariadicArg() const;
166 
167   // Returns true if the operation has a single variadic result.
hasSingleVariadicResult()168   bool hasSingleVariadicResult() const {
169     return getNumResults() == 1 && getResult(0).isVariadic();
170   }
171 
172   // Returns true of the operation has no variadic regions.
hasNoVariadicRegions()173   bool hasNoVariadicRegions() const { return getNumVariadicRegions() == 0; }
174 
175   using arg_iterator = const Argument *;
176   using arg_range = llvm::iterator_range<arg_iterator>;
177 
178   // Op argument (attribute or operand) iterators.
179   arg_iterator arg_begin() const;
180   arg_iterator arg_end() const;
181   arg_range getArgs() const;
182 
183   // Op argument (attribute or operand) accessors.
184   Argument getArg(int index) const;
185   StringRef getArgName(int index) const;
186   var_decorator_range getArgDecorators(int index) const;
187 
188   // Returns the trait wrapper for the given MLIR C++ `trait`.
189   const Trait *getTrait(llvm::StringRef trait) const;
190 
191   // Regions.
192   using const_region_iterator = const NamedRegion *;
193   const_region_iterator region_begin() const;
194   const_region_iterator region_end() const;
195   llvm::iterator_range<const_region_iterator> getRegions() const;
196 
197   // Returns the number of regions.
198   unsigned getNumRegions() const;
199   // Returns the `index`-th region.
200   const NamedRegion &getRegion(unsigned index) const;
201 
202   // Returns the number of variadic regions in this operation.
203   unsigned getNumVariadicRegions() const;
204 
205   // Successors.
206   using const_successor_iterator = const NamedSuccessor *;
207   const_successor_iterator successor_begin() const;
208   const_successor_iterator successor_end() const;
209   llvm::iterator_range<const_successor_iterator> getSuccessors() const;
210 
211   // Returns the number of successors.
212   unsigned getNumSuccessors() const;
213   // Returns the `index`-th successor.
214   const NamedSuccessor &getSuccessor(unsigned index) const;
215 
216   // Returns the number of variadic successors in this operation.
217   unsigned getNumVariadicSuccessors() const;
218 
219   // Trait.
220   using const_trait_iterator = const Trait *;
221   const_trait_iterator trait_begin() const;
222   const_trait_iterator trait_end() const;
223   llvm::iterator_range<const_trait_iterator> getTraits() const;
224 
225   ArrayRef<SMLoc> getLoc() const;
226 
227   // Query functions for the documentation of the operator.
228   bool hasDescription() const;
229   StringRef getDescription() const;
230   bool hasSummary() const;
231   StringRef getSummary() const;
232 
233   // Query functions for the assembly format of the operator.
234   bool hasAssemblyFormat() const;
235   StringRef getAssemblyFormat() const;
236 
237   // Returns this op's extra class declaration code.
238   StringRef getExtraClassDeclaration() const;
239 
240   // Returns this op's extra class definition code.
241   StringRef getExtraClassDefinition() const;
242 
243   // Returns the Tablegen definition this operator was constructed from.
244   // TODO: do not expose the TableGen record, this is a temporary solution to
245   // OpEmitter requiring a Record because Operator does not provide enough
246   // methods.
247   const llvm::Record &getDef() const;
248 
249   // Returns the dialect of the op.
getDialect()250   const Dialect &getDialect() const { return dialect; }
251 
252   // Prints the contents in this operator to the given `os`. This is used for
253   // debugging purposes.
254   void print(llvm::raw_ostream &os) const;
255 
256   // Return whether all the result types are known.
allResultTypesKnown()257   bool allResultTypesKnown() const { return allResultsHaveKnownTypes; };
258 
259   // Pair representing either a index to an argument or a type constraint. Only
260   // one of these entries should have the non-default value.
261   struct ArgOrType {
ArgOrTypeArgOrType262     explicit ArgOrType(int index) : index(index), constraint(None) {}
ArgOrTypeArgOrType263     explicit ArgOrType(TypeConstraint constraint)
264         : index(None), constraint(constraint) {}
isArgArgOrType265     bool isArg() const {
266       assert(constraint.has_value() ^ index.has_value());
267       return index.has_value();
268     }
isTypeArgOrType269     bool isType() const {
270       assert(constraint.has_value() ^ index.has_value());
271       return constraint.has_value();
272     }
273 
getArgArgOrType274     int getArg() const { return *index; }
getTypeArgOrType275     TypeConstraint getType() const { return *constraint; }
276 
277   private:
278     Optional<int> index;
279     Optional<TypeConstraint> constraint;
280   };
281 
282   // Return all arguments or type constraints with same type as result[index].
283   // Requires: all result types are known.
284   ArrayRef<ArgOrType> getSameTypeAsResult(int index) const;
285 
286   // Pair consisting kind of argument and index into operands or attributes.
287   struct OperandOrAttribute {
288     enum class Kind { Operand, Attribute };
OperandOrAttributeOperandOrAttribute289     OperandOrAttribute(Kind kind, int index) {
290       packed = (index << 1) | (kind == Kind::Attribute);
291     }
operandOrAttributeIndexOperandOrAttribute292     int operandOrAttributeIndex() const { return (packed >> 1); }
kindOperandOrAttribute293     Kind kind() { return (packed & 0x1) ? Kind::Attribute : Kind::Operand; }
294 
295   private:
296     int packed;
297   };
298 
299   // Returns the OperandOrAttribute corresponding to the index.
300   OperandOrAttribute getArgToOperandOrAttribute(int index) const;
301 
302   // Returns the builders of this operation.
getBuilders()303   ArrayRef<Builder> getBuilders() const { return builders; }
304 
305   // Returns the preferred getter name for the accessor.
getGetterName(StringRef name)306   std::string getGetterName(StringRef name) const {
307     return getGetterNames(name).front();
308   }
309 
310   // Returns the getter names for the accessor.
311   SmallVector<std::string, 2> getGetterNames(StringRef name) const;
312 
313   // Returns the setter names for the accessor.
314   SmallVector<std::string, 2> getSetterNames(StringRef name) const;
315 
316 private:
317   // Populates the vectors containing operands, attributes, results and traits.
318   void populateOpStructure();
319 
320   // Populates type inference info (mostly equality) with input a mapping from
321   // names to indices for arguments and results.
322   void populateTypeInferenceInfo(
323       const llvm::StringMap<int> &argumentsAndResultsIndex);
324 
325   // The dialect of this op.
326   Dialect dialect;
327 
328   // The unqualified C++ class name of the op.
329   StringRef cppClassName;
330 
331   // The C++ namespace for this op.
332   StringRef cppNamespace;
333 
334   // The operands of the op.
335   SmallVector<NamedTypeConstraint, 4> operands;
336 
337   // The attributes of the op.  Contains native attributes (corresponding to the
338   // actual stored attributed of the operation) followed by derived attributes
339   // (corresponding to dynamic properties of the operation that are computed
340   // upon request).
341   SmallVector<NamedAttribute, 4> attributes;
342 
343   // The arguments of the op (operands and native attributes).
344   SmallVector<Argument, 4> arguments;
345 
346   // The results of the op.
347   SmallVector<NamedTypeConstraint, 4> results;
348 
349   // The successors of this op.
350   SmallVector<NamedSuccessor, 0> successors;
351 
352   // The traits of the op.
353   SmallVector<Trait, 4> traits;
354 
355   // The regions of this op.
356   SmallVector<NamedRegion, 1> regions;
357 
358   // The argument with the same type as the result.
359   SmallVector<SmallVector<ArgOrType, 2>, 4> resultTypeMapping;
360 
361   // Map from argument to attribute or operand number.
362   SmallVector<OperandOrAttribute, 4> attrOrOperandMapping;
363 
364   // The builders of this operator.
365   SmallVector<Builder> builders;
366 
367   // The number of native attributes stored in the leading positions of
368   // `attributes`.
369   int numNativeAttributes;
370 
371   // The TableGen definition of this op.
372   const llvm::Record &def;
373 
374   // Whether the type of all results are known.
375   bool allResultsHaveKnownTypes;
376 };
377 
378 } // namespace tblgen
379 } // namespace mlir
380 
381 #endif // MLIR_TABLEGEN_OPERATOR_H_
382