1 //===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TABLEGEN_PATTERN_H_
15 #define MLIR_TABLEGEN_PATTERN_H_
16 
17 #include "mlir/Support/LLVM.h"
18 #include "mlir/TableGen/Argument.h"
19 #include "mlir/TableGen/Operator.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/Hashing.h"
22 #include "llvm/ADT/StringMap.h"
23 #include "llvm/ADT/StringSet.h"
24 
25 #include <unordered_map>
26 
27 namespace llvm {
28 class DagInit;
29 class Init;
30 class Record;
31 } // namespace llvm
32 
33 namespace mlir {
34 namespace tblgen {
35 
36 // Mapping from TableGen Record to Operator wrapper object.
37 //
38 // We allocate each wrapper object in heap to make sure the pointer to it is
39 // valid throughout the lifetime of this map. This is important because this map
40 // is shared among multiple patterns to avoid creating the wrapper object for
41 // the same op again and again. But this map will continuously grow.
42 using RecordOperatorMap =
43     DenseMap<const llvm::Record *, std::unique_ptr<Operator>>;
44 
45 class Pattern;
46 
47 // Wrapper class providing helper methods for accessing TableGen DAG leaves
48 // used inside Patterns. This class is lightweight and designed to be used like
49 // values.
50 //
51 // A TableGen DAG construct is of the syntax
52 //   `(operator, arg0, arg1, ...)`.
53 //
54 // This class provides getters to retrieve `arg*` as tblgen:: wrapper objects
55 // for handy helper methods. It only works on `arg*`s that are not nested DAG
56 // constructs.
57 class DagLeaf {
58 public:
DagLeaf(const llvm::Init * def)59   explicit DagLeaf(const llvm::Init *def) : def(def) {}
60 
61   // Returns true if this DAG leaf is not specified in the pattern. That is, it
62   // places no further constraints/transforms and just carries over the original
63   // value.
64   bool isUnspecified() const;
65 
66   // Returns true if this DAG leaf is matching an operand. That is, it specifies
67   // a type constraint.
68   bool isOperandMatcher() const;
69 
70   // Returns true if this DAG leaf is matching an attribute. That is, it
71   // specifies an attribute constraint.
72   bool isAttrMatcher() const;
73 
74   // Returns true if this DAG leaf is wrapping native code call.
75   bool isNativeCodeCall() const;
76 
77   // Returns true if this DAG leaf is specifying a constant attribute.
78   bool isConstantAttr() const;
79 
80   // Returns true if this DAG leaf is specifying an enum attribute case.
81   bool isEnumAttrCase() const;
82 
83   // Returns true if this DAG leaf is specifying a string attribute.
84   bool isStringAttr() const;
85 
86   // Returns this DAG leaf as a constraint. Asserts if fails.
87   Constraint getAsConstraint() const;
88 
89   // Returns this DAG leaf as an constant attribute. Asserts if fails.
90   ConstantAttr getAsConstantAttr() const;
91 
92   // Returns this DAG leaf as an enum attribute case.
93   // Precondition: isEnumAttrCase()
94   EnumAttrCase getAsEnumAttrCase() const;
95 
96   // Returns the matching condition template inside this DAG leaf. Assumes the
97   // leaf is an operand/attribute matcher and asserts otherwise.
98   std::string getConditionTemplate() const;
99 
100   // Returns the native code call template inside this DAG leaf.
101   // Precondition: isNativeCodeCall()
102   StringRef getNativeCodeTemplate() const;
103 
104   // Returns the number of values will be returned by the native helper
105   // function.
106   // Precondition: isNativeCodeCall()
107   int getNumReturnsOfNativeCode() const;
108 
109   // Returns the string associated with the leaf.
110   // Precondition: isStringAttr()
111   std::string getStringAttr() const;
112 
113   void print(raw_ostream &os) const;
114 
115 private:
116   friend llvm::DenseMapInfo<DagLeaf>;
getAsOpaquePointer()117   const void *getAsOpaquePointer() const { return def; }
118 
119   // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
120   // also a subclass of the given `superclass`.
121   bool isSubClassOf(StringRef superclass) const;
122 
123   const llvm::Init *def;
124 };
125 
126 // Wrapper class providing helper methods for accessing TableGen DAG constructs
127 // used inside Patterns. This class is lightweight and designed to be used like
128 // values.
129 //
130 // A TableGen DAG construct is of the syntax
131 //   `(operator, arg0, arg1, ...)`.
132 //
133 // When used inside Patterns, `operator` corresponds to some dialect op, or
134 // a known list of verbs that defines special transformation actions. This
135 // `arg*` can be a nested DAG construct. This class provides getters to
136 // retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper
137 // methods.
138 //
139 // A null DagNode contains a nullptr and converts to false implicitly.
140 class DagNode {
141 public:
DagNode(const llvm::DagInit * node)142   explicit DagNode(const llvm::DagInit *node) : node(node) {}
143 
144   // Implicit bool converter that returns true if this DagNode is not a null
145   // DagNode.
146   operator bool() const { return node != nullptr; }
147 
148   // Returns the symbol bound to this DAG node.
149   StringRef getSymbol() const;
150 
151   // Returns the operator wrapper object corresponding to the dialect op matched
152   // by this DAG. The operator wrapper will be queried from the given `mapper`
153   // and created in it if not existing.
154   Operator &getDialectOp(RecordOperatorMap *mapper) const;
155 
156   // Returns the number of operations recursively involved in the DAG tree
157   // rooted from this node.
158   int getNumOps() const;
159 
160   // Returns the number of immediate arguments to this DAG node.
161   int getNumArgs() const;
162 
163   // Returns true if the `index`-th argument is a nested DAG construct.
164   bool isNestedDagArg(unsigned index) const;
165 
166   // Gets the `index`-th argument as a nested DAG construct if possible. Returns
167   // null DagNode otherwise.
168   DagNode getArgAsNestedDag(unsigned index) const;
169 
170   // Gets the `index`-th argument as a DAG leaf.
171   DagLeaf getArgAsLeaf(unsigned index) const;
172 
173   // Returns the specified name of the `index`-th argument.
174   StringRef getArgName(unsigned index) const;
175 
176   // Returns true if this DAG construct means to replace with an existing SSA
177   // value.
178   bool isReplaceWithValue() const;
179 
180   // Returns whether this DAG represents the location of an op creation.
181   bool isLocationDirective() const;
182 
183   // Returns whether this DAG is a return type specifier.
184   bool isReturnTypeDirective() const;
185 
186   // Returns true if this DAG node is wrapping native code call.
187   bool isNativeCodeCall() const;
188 
189   // Returns whether this DAG is an `either` specifier.
190   bool isEither() const;
191 
192   // Returns true if this DAG node is an operation.
193   bool isOperation() const;
194 
195   // Returns the native code call template inside this DAG node.
196   // Precondition: isNativeCodeCall()
197   StringRef getNativeCodeTemplate() const;
198 
199   // Returns the number of values will be returned by the native helper
200   // function.
201   // Precondition: isNativeCodeCall()
202   int getNumReturnsOfNativeCode() const;
203 
204   void print(raw_ostream &os) const;
205 
206 private:
207   friend class SymbolInfoMap;
208   friend llvm::DenseMapInfo<DagNode>;
getAsOpaquePointer()209   const void *getAsOpaquePointer() const { return node; }
210 
211   const llvm::DagInit *node; // nullptr means null DagNode
212 };
213 
214 // A class for maintaining information for symbols bound in patterns and
215 // provides methods for resolving them according to specific use cases.
216 //
217 // Symbols can be bound to
218 //
219 // * Op arguments and op results in the source pattern and
220 // * Op results in result patterns.
221 //
222 // Symbols can be referenced in result patterns and additional constraints to
223 // the pattern.
224 //
225 // For example, in
226 //
227 // ```
228 // def : Pattern<
229 //     (SrcOp:$results1 $arg0, %arg1),
230 //     [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>;
231 // ```
232 //
233 // `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to
234 // `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build
235 // `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`.
236 //
237 // If a symbol binds to a multi-result op and it does not have the `__N`
238 // suffix, the symbol is expanded to represent all results generated by the
239 // multi-result op. If the symbol has a `__N` suffix, then it will expand to
240 // only the N-th *static* result as declared in ODS, and that can still
241 // corresponds to multiple *dynamic* values if the N-th *static* result is
242 // variadic.
243 //
244 // This class keeps track of such symbols and resolves them into their bound
245 // values in a suitable way.
246 class SymbolInfoMap {
247 public:
SymbolInfoMap(ArrayRef<SMLoc> loc)248   explicit SymbolInfoMap(ArrayRef<SMLoc> loc) : loc(loc) {}
249 
250   // Class for information regarding a symbol.
251   class SymbolInfo {
252   public:
253     // Returns a type string of a variable.
254     std::string getVarTypeStr(StringRef name) const;
255 
256     // Returns a string for defining a variable named as `name` to store the
257     // value bound by this symbol.
258     std::string getVarDecl(StringRef name) const;
259 
260     // Returns a string for defining an argument which passes the reference of
261     // the variable.
262     std::string getArgDecl(StringRef name) const;
263 
264     // Returns a variable name for the symbol named as `name`.
265     std::string getVarName(StringRef name) const;
266 
267   private:
268     // Allow SymbolInfoMap to access private methods.
269     friend class SymbolInfoMap;
270 
271     // DagNode and DagLeaf are accessed by value which means it can't be used as
272     // identifier here. Use an opaque pointer type instead.
273     using DagAndConstant = std::pair<const void *, int>;
274 
275     // What kind of entity this symbol represents:
276     // * Attr: op attribute
277     // * Operand: op operand
278     // * Result: op result
279     // * Value: a value not attached to an op (e.g., from NativeCodeCall)
280     // * MultipleValues: a pack of values not attached to an op (e.g., from
281     //   NativeCodeCall). This kind supports indexing.
282     enum class Kind : uint8_t { Attr, Operand, Result, Value, MultipleValues };
283 
284     // Creates a SymbolInfo instance. `dagAndConstant` is only used for `Attr`
285     // and `Operand` so should be llvm::None for `Result` and `Value` kind.
286     SymbolInfo(const Operator *op, Kind kind,
287                Optional<DagAndConstant> dagAndConstant);
288 
289     // Static methods for creating SymbolInfo.
getAttr(const Operator * op,int index)290     static SymbolInfo getAttr(const Operator *op, int index) {
291       return SymbolInfo(op, Kind::Attr, DagAndConstant(nullptr, index));
292     }
getAttr()293     static SymbolInfo getAttr() {
294       return SymbolInfo(nullptr, Kind::Attr, llvm::None);
295     }
getOperand(DagNode node,const Operator * op,int index)296     static SymbolInfo getOperand(DagNode node, const Operator *op, int index) {
297       return SymbolInfo(op, Kind::Operand,
298                         DagAndConstant(node.getAsOpaquePointer(), index));
299     }
getResult(const Operator * op)300     static SymbolInfo getResult(const Operator *op) {
301       return SymbolInfo(op, Kind::Result, llvm::None);
302     }
getValue()303     static SymbolInfo getValue() {
304       return SymbolInfo(nullptr, Kind::Value, llvm::None);
305     }
getMultipleValues(int numValues)306     static SymbolInfo getMultipleValues(int numValues) {
307       return SymbolInfo(nullptr, Kind::MultipleValues,
308                         DagAndConstant(nullptr, numValues));
309     }
310 
311     // Returns the number of static values this symbol corresponds to.
312     // A static value is an operand/result declared in ODS. Normally a symbol
313     // only represents one static value, but symbols bound to op results can
314     // represent more than one if the op is a multi-result op.
315     int getStaticValueCount() const;
316 
317     // Returns a string containing the C++ expression for referencing this
318     // symbol as a value (if this symbol represents one static value) or a value
319     // range (if this symbol represents multiple static values). `name` is the
320     // name of the C++ variable that this symbol bounds to. `index` should only
321     // be used for indexing results.  `fmt` is used to format each value.
322     // `separator` is used to separate values if this is a value range.
323     std::string getValueAndRangeUse(StringRef name, int index, const char *fmt,
324                                     const char *separator) const;
325 
326     // Returns a string containing the C++ expression for referencing this
327     // symbol as a value range regardless of how many static values this symbol
328     // represents. `name` is the name of the C++ variable that this symbol
329     // bounds to. `index` should only be used for indexing results. `fmt` is
330     // used to format each value. `separator` is used to separate values in the
331     // range.
332     std::string getAllRangeUse(StringRef name, int index, const char *fmt,
333                                const char *separator) const;
334 
335     // The argument index (for `Attr` and `Operand` only)
getArgIndex()336     int getArgIndex() const { return (*dagAndConstant).second; }
337 
338     // The number of values in the MultipleValue
getSize()339     int getSize() const { return (*dagAndConstant).second; }
340 
341     const Operator *op; // The op where the bound entity belongs
342     Kind kind;          // The kind of the bound entity
343 
344     // The pair of DagNode pointer and constant value (for `Attr`, `Operand` and
345     // the size of MultipleValue symbol). Note that operands may be bound to the
346     // same symbol, use the DagNode and index to distinguish them. For `Attr`
347     // and MultipleValue, the Dag part will be nullptr.
348     Optional<DagAndConstant> dagAndConstant;
349 
350     // Alternative name for the symbol. It is used in case the name
351     // is not unique. Applicable for `Operand` only.
352     Optional<std::string> alternativeName;
353   };
354 
355   using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
356 
357   // Iterators for accessing all symbols.
358   using iterator = BaseT::iterator;
begin()359   iterator begin() { return symbolInfoMap.begin(); }
end()360   iterator end() { return symbolInfoMap.end(); }
361 
362   // Const iterators for accessing all symbols.
363   using const_iterator = BaseT::const_iterator;
begin()364   const_iterator begin() const { return symbolInfoMap.begin(); }
end()365   const_iterator end() const { return symbolInfoMap.end(); }
366 
367   // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
368   // Returns false if `symbol` is already bound and symbols are not operands.
369   bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op,
370                       int argIndex);
371 
372   // Binds the given `symbol` to the results the given `op`. Returns false if
373   // `symbol` is already bound.
374   bool bindOpResult(StringRef symbol, const Operator &op);
375 
376   // A helper function for dispatching target value binding functions.
377   bool bindValues(StringRef symbol, int numValues = 1);
378 
379   // Registers the given `symbol` as bound to the Value(s). Returns false if
380   // `symbol` is already bound.
381   bool bindValue(StringRef symbol);
382 
383   // Registers the given `symbol` as bound to a MultipleValue. Return false if
384   // `symbol` is already bound.
385   bool bindMultipleValues(StringRef symbol, int numValues);
386 
387   // Registers the given `symbol` as bound to an attr. Returns false if `symbol`
388   // is already bound.
389   bool bindAttr(StringRef symbol);
390 
391   // Returns true if the given `symbol` is bound.
392   bool contains(StringRef symbol) const;
393 
394   // Returns an iterator to the information of the given symbol named as `key`.
395   const_iterator find(StringRef key) const;
396 
397   // Returns an iterator to the information of the given symbol named as `key`,
398   // with index `argIndex` for operator `op`.
399   const_iterator findBoundSymbol(StringRef key, DagNode node,
400                                  const Operator &op, int argIndex) const;
401   const_iterator findBoundSymbol(StringRef key,
402                                  const SymbolInfo &symbolInfo) const;
403 
404   // Returns the bounds of a range that includes all the elements which
405   // bind to the `key`.
406   std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
407 
408   // Returns number of times symbol named as `key` was used.
409   int count(StringRef key) const;
410 
411   // Returns the number of static values of the given `symbol` corresponds to.
412   // A static value is an operand/result declared in ODS. Normally a symbol only
413   // represents one static value, but symbols bound to op results can represent
414   // more than one if the op is a multi-result op.
415   int getStaticValueCount(StringRef symbol) const;
416 
417   // Returns a string containing the C++ expression for referencing this
418   // symbol as a value (if this symbol represents one static value) or a value
419   // range (if this symbol represents multiple static values). `fmt` is used to
420   // format each value. `separator` is used to separate values if `symbol`
421   // represents a value range.
422   std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}",
423                                   const char *separator = ", ") const;
424 
425   // Returns a string containing the C++ expression for referencing this
426   // symbol as a value range regardless of how many static values this symbol
427   // represents. `fmt` is used to format each value. `separator` is used to
428   // separate values in the range.
429   std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
430                              const char *separator = ", ") const;
431 
432   // Assign alternative unique names to Operands that have equal names.
433   void assignUniqueAlternativeNames();
434 
435   // Splits the given `symbol` into a value pack name and an index. Returns the
436   // value pack name and writes the index to `index` on success. Returns
437   // `symbol` itself if it does not contain an index.
438   //
439   // We can use `name__N` to access the `N`-th value in the value pack bound to
440   // `name`. `name` is typically the results of an multi-result op.
441   static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
442 
443 private:
444   BaseT symbolInfoMap;
445 
446   // Pattern instantiation location. This is intended to be used as parameter
447   // to PrintFatalError() to report errors.
448   ArrayRef<SMLoc> loc;
449 };
450 
451 // Wrapper class providing helper methods for accessing MLIR Pattern defined
452 // in TableGen. This class should closely reflect what is defined as class
453 // `Pattern` in TableGen. This class contains maps so it is not intended to be
454 // used as values.
455 class Pattern {
456 public:
457   explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper);
458 
459   // Returns the source pattern to match.
460   DagNode getSourcePattern() const;
461 
462   // Returns the number of result patterns generated by applying this rewrite
463   // rule.
464   int getNumResultPatterns() const;
465 
466   // Returns the DAG tree root node of the `index`-th result pattern.
467   DagNode getResultPattern(unsigned index) const;
468 
469   // Collects all symbols bound in the source pattern into `infoMap`.
470   void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap);
471 
472   // Collects all symbols bound in result patterns into `infoMap`.
473   void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap);
474 
475   // Returns the op that the root node of the source pattern matches.
476   const Operator &getSourceRootOp();
477 
478   // Returns the operator wrapper object corresponding to the given `node`'s DAG
479   // operator.
480   Operator &getDialectOp(DagNode node);
481 
482   // Returns the constraints.
483   std::vector<AppliedConstraint> getConstraints() const;
484 
485   // Returns the benefit score of the pattern.
486   int getBenefit() const;
487 
488   using IdentifierLine = std::pair<StringRef, unsigned>;
489 
490   // Returns the file location of the pattern (buffer identifier + line number
491   // pair).
492   std::vector<IdentifierLine> getLocation() const;
493 
494   // Recursively collects all bound symbols inside the DAG tree rooted
495   // at `tree` and updates the given `infoMap`.
496   void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
497                            bool isSrcPattern);
498 
499 private:
500   // Helper function to verify variable binding.
501   void verifyBind(bool result, StringRef symbolName);
502 
503   // The TableGen definition of this pattern.
504   const llvm::Record &def;
505 
506   // All operators.
507   // TODO: we need a proper context manager, like MLIRContext, for managing the
508   // lifetime of shared entities.
509   RecordOperatorMap *recordOpMap;
510 };
511 
512 } // namespace tblgen
513 } // namespace mlir
514 
515 namespace llvm {
516 template <>
517 struct DenseMapInfo<mlir::tblgen::DagNode> {
518   static mlir::tblgen::DagNode getEmptyKey() {
519     return mlir::tblgen::DagNode(
520         llvm::DenseMapInfo<llvm::DagInit *>::getEmptyKey());
521   }
522   static mlir::tblgen::DagNode getTombstoneKey() {
523     return mlir::tblgen::DagNode(
524         llvm::DenseMapInfo<llvm::DagInit *>::getTombstoneKey());
525   }
526   static unsigned getHashValue(mlir::tblgen::DagNode node) {
527     return llvm::hash_value(node.getAsOpaquePointer());
528   }
529   static bool isEqual(mlir::tblgen::DagNode lhs, mlir::tblgen::DagNode rhs) {
530     return lhs.node == rhs.node;
531   }
532 };
533 
534 template <>
535 struct DenseMapInfo<mlir::tblgen::DagLeaf> {
536   static mlir::tblgen::DagLeaf getEmptyKey() {
537     return mlir::tblgen::DagLeaf(
538         llvm::DenseMapInfo<llvm::Init *>::getEmptyKey());
539   }
540   static mlir::tblgen::DagLeaf getTombstoneKey() {
541     return mlir::tblgen::DagLeaf(
542         llvm::DenseMapInfo<llvm::Init *>::getTombstoneKey());
543   }
544   static unsigned getHashValue(mlir::tblgen::DagLeaf leaf) {
545     return llvm::hash_value(leaf.getAsOpaquePointer());
546   }
547   static bool isEqual(mlir::tblgen::DagLeaf lhs, mlir::tblgen::DagLeaf rhs) {
548     return lhs.def == rhs.def;
549   }
550 };
551 } // namespace llvm
552 
553 #endif // MLIR_TABLEGEN_PATTERN_H_
554