1 //===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 10 // Pattern. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_TABLEGEN_PATTERN_H_ 15 #define MLIR_TABLEGEN_PATTERN_H_ 16 17 #include "mlir/Support/LLVM.h" 18 #include "mlir/TableGen/Argument.h" 19 #include "mlir/TableGen/Operator.h" 20 #include "llvm/ADT/DenseMap.h" 21 #include "llvm/ADT/Hashing.h" 22 #include "llvm/ADT/StringMap.h" 23 #include "llvm/ADT/StringSet.h" 24 25 #include <unordered_map> 26 27 namespace llvm { 28 class DagInit; 29 class Init; 30 class Record; 31 } // namespace llvm 32 33 namespace mlir { 34 namespace tblgen { 35 36 // Mapping from TableGen Record to Operator wrapper object. 37 // 38 // We allocate each wrapper object in heap to make sure the pointer to it is 39 // valid throughout the lifetime of this map. This is important because this map 40 // is shared among multiple patterns to avoid creating the wrapper object for 41 // the same op again and again. But this map will continuously grow. 42 using RecordOperatorMap = 43 DenseMap<const llvm::Record *, std::unique_ptr<Operator>>; 44 45 class Pattern; 46 47 // Wrapper class providing helper methods for accessing TableGen DAG leaves 48 // used inside Patterns. This class is lightweight and designed to be used like 49 // values. 50 // 51 // A TableGen DAG construct is of the syntax 52 // `(operator, arg0, arg1, ...)`. 53 // 54 // This class provides getters to retrieve `arg*` as tblgen:: wrapper objects 55 // for handy helper methods. It only works on `arg*`s that are not nested DAG 56 // constructs. 57 class DagLeaf { 58 public: DagLeaf(const llvm::Init * def)59 explicit DagLeaf(const llvm::Init *def) : def(def) {} 60 61 // Returns true if this DAG leaf is not specified in the pattern. That is, it 62 // places no further constraints/transforms and just carries over the original 63 // value. 64 bool isUnspecified() const; 65 66 // Returns true if this DAG leaf is matching an operand. That is, it specifies 67 // a type constraint. 68 bool isOperandMatcher() const; 69 70 // Returns true if this DAG leaf is matching an attribute. That is, it 71 // specifies an attribute constraint. 72 bool isAttrMatcher() const; 73 74 // Returns true if this DAG leaf is wrapping native code call. 75 bool isNativeCodeCall() const; 76 77 // Returns true if this DAG leaf is specifying a constant attribute. 78 bool isConstantAttr() const; 79 80 // Returns true if this DAG leaf is specifying an enum attribute case. 81 bool isEnumAttrCase() const; 82 83 // Returns true if this DAG leaf is specifying a string attribute. 84 bool isStringAttr() const; 85 86 // Returns this DAG leaf as a constraint. Asserts if fails. 87 Constraint getAsConstraint() const; 88 89 // Returns this DAG leaf as an constant attribute. Asserts if fails. 90 ConstantAttr getAsConstantAttr() const; 91 92 // Returns this DAG leaf as an enum attribute case. 93 // Precondition: isEnumAttrCase() 94 EnumAttrCase getAsEnumAttrCase() const; 95 96 // Returns the matching condition template inside this DAG leaf. Assumes the 97 // leaf is an operand/attribute matcher and asserts otherwise. 98 std::string getConditionTemplate() const; 99 100 // Returns the native code call template inside this DAG leaf. 101 // Precondition: isNativeCodeCall() 102 StringRef getNativeCodeTemplate() const; 103 104 // Returns the number of values will be returned by the native helper 105 // function. 106 // Precondition: isNativeCodeCall() 107 int getNumReturnsOfNativeCode() const; 108 109 // Returns the string associated with the leaf. 110 // Precondition: isStringAttr() 111 std::string getStringAttr() const; 112 113 void print(raw_ostream &os) const; 114 115 private: 116 friend llvm::DenseMapInfo<DagLeaf>; getAsOpaquePointer()117 const void *getAsOpaquePointer() const { return def; } 118 119 // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and 120 // also a subclass of the given `superclass`. 121 bool isSubClassOf(StringRef superclass) const; 122 123 const llvm::Init *def; 124 }; 125 126 // Wrapper class providing helper methods for accessing TableGen DAG constructs 127 // used inside Patterns. This class is lightweight and designed to be used like 128 // values. 129 // 130 // A TableGen DAG construct is of the syntax 131 // `(operator, arg0, arg1, ...)`. 132 // 133 // When used inside Patterns, `operator` corresponds to some dialect op, or 134 // a known list of verbs that defines special transformation actions. This 135 // `arg*` can be a nested DAG construct. This class provides getters to 136 // retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper 137 // methods. 138 // 139 // A null DagNode contains a nullptr and converts to false implicitly. 140 class DagNode { 141 public: DagNode(const llvm::DagInit * node)142 explicit DagNode(const llvm::DagInit *node) : node(node) {} 143 144 // Implicit bool converter that returns true if this DagNode is not a null 145 // DagNode. 146 operator bool() const { return node != nullptr; } 147 148 // Returns the symbol bound to this DAG node. 149 StringRef getSymbol() const; 150 151 // Returns the operator wrapper object corresponding to the dialect op matched 152 // by this DAG. The operator wrapper will be queried from the given `mapper` 153 // and created in it if not existing. 154 Operator &getDialectOp(RecordOperatorMap *mapper) const; 155 156 // Returns the number of operations recursively involved in the DAG tree 157 // rooted from this node. 158 int getNumOps() const; 159 160 // Returns the number of immediate arguments to this DAG node. 161 int getNumArgs() const; 162 163 // Returns true if the `index`-th argument is a nested DAG construct. 164 bool isNestedDagArg(unsigned index) const; 165 166 // Gets the `index`-th argument as a nested DAG construct if possible. Returns 167 // null DagNode otherwise. 168 DagNode getArgAsNestedDag(unsigned index) const; 169 170 // Gets the `index`-th argument as a DAG leaf. 171 DagLeaf getArgAsLeaf(unsigned index) const; 172 173 // Returns the specified name of the `index`-th argument. 174 StringRef getArgName(unsigned index) const; 175 176 // Returns true if this DAG construct means to replace with an existing SSA 177 // value. 178 bool isReplaceWithValue() const; 179 180 // Returns whether this DAG represents the location of an op creation. 181 bool isLocationDirective() const; 182 183 // Returns whether this DAG is a return type specifier. 184 bool isReturnTypeDirective() const; 185 186 // Returns true if this DAG node is wrapping native code call. 187 bool isNativeCodeCall() const; 188 189 // Returns whether this DAG is an `either` specifier. 190 bool isEither() const; 191 192 // Returns true if this DAG node is an operation. 193 bool isOperation() const; 194 195 // Returns the native code call template inside this DAG node. 196 // Precondition: isNativeCodeCall() 197 StringRef getNativeCodeTemplate() const; 198 199 // Returns the number of values will be returned by the native helper 200 // function. 201 // Precondition: isNativeCodeCall() 202 int getNumReturnsOfNativeCode() const; 203 204 void print(raw_ostream &os) const; 205 206 private: 207 friend class SymbolInfoMap; 208 friend llvm::DenseMapInfo<DagNode>; getAsOpaquePointer()209 const void *getAsOpaquePointer() const { return node; } 210 211 const llvm::DagInit *node; // nullptr means null DagNode 212 }; 213 214 // A class for maintaining information for symbols bound in patterns and 215 // provides methods for resolving them according to specific use cases. 216 // 217 // Symbols can be bound to 218 // 219 // * Op arguments and op results in the source pattern and 220 // * Op results in result patterns. 221 // 222 // Symbols can be referenced in result patterns and additional constraints to 223 // the pattern. 224 // 225 // For example, in 226 // 227 // ``` 228 // def : Pattern< 229 // (SrcOp:$results1 $arg0, %arg1), 230 // [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>; 231 // ``` 232 // 233 // `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to 234 // `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build 235 // `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`. 236 // 237 // If a symbol binds to a multi-result op and it does not have the `__N` 238 // suffix, the symbol is expanded to represent all results generated by the 239 // multi-result op. If the symbol has a `__N` suffix, then it will expand to 240 // only the N-th *static* result as declared in ODS, and that can still 241 // corresponds to multiple *dynamic* values if the N-th *static* result is 242 // variadic. 243 // 244 // This class keeps track of such symbols and resolves them into their bound 245 // values in a suitable way. 246 class SymbolInfoMap { 247 public: SymbolInfoMap(ArrayRef<SMLoc> loc)248 explicit SymbolInfoMap(ArrayRef<SMLoc> loc) : loc(loc) {} 249 250 // Class for information regarding a symbol. 251 class SymbolInfo { 252 public: 253 // Returns a type string of a variable. 254 std::string getVarTypeStr(StringRef name) const; 255 256 // Returns a string for defining a variable named as `name` to store the 257 // value bound by this symbol. 258 std::string getVarDecl(StringRef name) const; 259 260 // Returns a string for defining an argument which passes the reference of 261 // the variable. 262 std::string getArgDecl(StringRef name) const; 263 264 // Returns a variable name for the symbol named as `name`. 265 std::string getVarName(StringRef name) const; 266 267 private: 268 // Allow SymbolInfoMap to access private methods. 269 friend class SymbolInfoMap; 270 271 // DagNode and DagLeaf are accessed by value which means it can't be used as 272 // identifier here. Use an opaque pointer type instead. 273 using DagAndConstant = std::pair<const void *, int>; 274 275 // What kind of entity this symbol represents: 276 // * Attr: op attribute 277 // * Operand: op operand 278 // * Result: op result 279 // * Value: a value not attached to an op (e.g., from NativeCodeCall) 280 // * MultipleValues: a pack of values not attached to an op (e.g., from 281 // NativeCodeCall). This kind supports indexing. 282 enum class Kind : uint8_t { Attr, Operand, Result, Value, MultipleValues }; 283 284 // Creates a SymbolInfo instance. `dagAndConstant` is only used for `Attr` 285 // and `Operand` so should be llvm::None for `Result` and `Value` kind. 286 SymbolInfo(const Operator *op, Kind kind, 287 Optional<DagAndConstant> dagAndConstant); 288 289 // Static methods for creating SymbolInfo. getAttr(const Operator * op,int index)290 static SymbolInfo getAttr(const Operator *op, int index) { 291 return SymbolInfo(op, Kind::Attr, DagAndConstant(nullptr, index)); 292 } getAttr()293 static SymbolInfo getAttr() { 294 return SymbolInfo(nullptr, Kind::Attr, llvm::None); 295 } getOperand(DagNode node,const Operator * op,int index)296 static SymbolInfo getOperand(DagNode node, const Operator *op, int index) { 297 return SymbolInfo(op, Kind::Operand, 298 DagAndConstant(node.getAsOpaquePointer(), index)); 299 } getResult(const Operator * op)300 static SymbolInfo getResult(const Operator *op) { 301 return SymbolInfo(op, Kind::Result, llvm::None); 302 } getValue()303 static SymbolInfo getValue() { 304 return SymbolInfo(nullptr, Kind::Value, llvm::None); 305 } getMultipleValues(int numValues)306 static SymbolInfo getMultipleValues(int numValues) { 307 return SymbolInfo(nullptr, Kind::MultipleValues, 308 DagAndConstant(nullptr, numValues)); 309 } 310 311 // Returns the number of static values this symbol corresponds to. 312 // A static value is an operand/result declared in ODS. Normally a symbol 313 // only represents one static value, but symbols bound to op results can 314 // represent more than one if the op is a multi-result op. 315 int getStaticValueCount() const; 316 317 // Returns a string containing the C++ expression for referencing this 318 // symbol as a value (if this symbol represents one static value) or a value 319 // range (if this symbol represents multiple static values). `name` is the 320 // name of the C++ variable that this symbol bounds to. `index` should only 321 // be used for indexing results. `fmt` is used to format each value. 322 // `separator` is used to separate values if this is a value range. 323 std::string getValueAndRangeUse(StringRef name, int index, const char *fmt, 324 const char *separator) const; 325 326 // Returns a string containing the C++ expression for referencing this 327 // symbol as a value range regardless of how many static values this symbol 328 // represents. `name` is the name of the C++ variable that this symbol 329 // bounds to. `index` should only be used for indexing results. `fmt` is 330 // used to format each value. `separator` is used to separate values in the 331 // range. 332 std::string getAllRangeUse(StringRef name, int index, const char *fmt, 333 const char *separator) const; 334 335 // The argument index (for `Attr` and `Operand` only) getArgIndex()336 int getArgIndex() const { return (*dagAndConstant).second; } 337 338 // The number of values in the MultipleValue getSize()339 int getSize() const { return (*dagAndConstant).second; } 340 341 const Operator *op; // The op where the bound entity belongs 342 Kind kind; // The kind of the bound entity 343 344 // The pair of DagNode pointer and constant value (for `Attr`, `Operand` and 345 // the size of MultipleValue symbol). Note that operands may be bound to the 346 // same symbol, use the DagNode and index to distinguish them. For `Attr` 347 // and MultipleValue, the Dag part will be nullptr. 348 Optional<DagAndConstant> dagAndConstant; 349 350 // Alternative name for the symbol. It is used in case the name 351 // is not unique. Applicable for `Operand` only. 352 Optional<std::string> alternativeName; 353 }; 354 355 using BaseT = std::unordered_multimap<std::string, SymbolInfo>; 356 357 // Iterators for accessing all symbols. 358 using iterator = BaseT::iterator; begin()359 iterator begin() { return symbolInfoMap.begin(); } end()360 iterator end() { return symbolInfoMap.end(); } 361 362 // Const iterators for accessing all symbols. 363 using const_iterator = BaseT::const_iterator; begin()364 const_iterator begin() const { return symbolInfoMap.begin(); } end()365 const_iterator end() const { return symbolInfoMap.end(); } 366 367 // Binds the given `symbol` to the `argIndex`-th argument to the given `op`. 368 // Returns false if `symbol` is already bound and symbols are not operands. 369 bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, 370 int argIndex); 371 372 // Binds the given `symbol` to the results the given `op`. Returns false if 373 // `symbol` is already bound. 374 bool bindOpResult(StringRef symbol, const Operator &op); 375 376 // A helper function for dispatching target value binding functions. 377 bool bindValues(StringRef symbol, int numValues = 1); 378 379 // Registers the given `symbol` as bound to the Value(s). Returns false if 380 // `symbol` is already bound. 381 bool bindValue(StringRef symbol); 382 383 // Registers the given `symbol` as bound to a MultipleValue. Return false if 384 // `symbol` is already bound. 385 bool bindMultipleValues(StringRef symbol, int numValues); 386 387 // Registers the given `symbol` as bound to an attr. Returns false if `symbol` 388 // is already bound. 389 bool bindAttr(StringRef symbol); 390 391 // Returns true if the given `symbol` is bound. 392 bool contains(StringRef symbol) const; 393 394 // Returns an iterator to the information of the given symbol named as `key`. 395 const_iterator find(StringRef key) const; 396 397 // Returns an iterator to the information of the given symbol named as `key`, 398 // with index `argIndex` for operator `op`. 399 const_iterator findBoundSymbol(StringRef key, DagNode node, 400 const Operator &op, int argIndex) const; 401 const_iterator findBoundSymbol(StringRef key, 402 const SymbolInfo &symbolInfo) const; 403 404 // Returns the bounds of a range that includes all the elements which 405 // bind to the `key`. 406 std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key); 407 408 // Returns number of times symbol named as `key` was used. 409 int count(StringRef key) const; 410 411 // Returns the number of static values of the given `symbol` corresponds to. 412 // A static value is an operand/result declared in ODS. Normally a symbol only 413 // represents one static value, but symbols bound to op results can represent 414 // more than one if the op is a multi-result op. 415 int getStaticValueCount(StringRef symbol) const; 416 417 // Returns a string containing the C++ expression for referencing this 418 // symbol as a value (if this symbol represents one static value) or a value 419 // range (if this symbol represents multiple static values). `fmt` is used to 420 // format each value. `separator` is used to separate values if `symbol` 421 // represents a value range. 422 std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}", 423 const char *separator = ", ") const; 424 425 // Returns a string containing the C++ expression for referencing this 426 // symbol as a value range regardless of how many static values this symbol 427 // represents. `fmt` is used to format each value. `separator` is used to 428 // separate values in the range. 429 std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}", 430 const char *separator = ", ") const; 431 432 // Assign alternative unique names to Operands that have equal names. 433 void assignUniqueAlternativeNames(); 434 435 // Splits the given `symbol` into a value pack name and an index. Returns the 436 // value pack name and writes the index to `index` on success. Returns 437 // `symbol` itself if it does not contain an index. 438 // 439 // We can use `name__N` to access the `N`-th value in the value pack bound to 440 // `name`. `name` is typically the results of an multi-result op. 441 static StringRef getValuePackName(StringRef symbol, int *index = nullptr); 442 443 private: 444 BaseT symbolInfoMap; 445 446 // Pattern instantiation location. This is intended to be used as parameter 447 // to PrintFatalError() to report errors. 448 ArrayRef<SMLoc> loc; 449 }; 450 451 // Wrapper class providing helper methods for accessing MLIR Pattern defined 452 // in TableGen. This class should closely reflect what is defined as class 453 // `Pattern` in TableGen. This class contains maps so it is not intended to be 454 // used as values. 455 class Pattern { 456 public: 457 explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper); 458 459 // Returns the source pattern to match. 460 DagNode getSourcePattern() const; 461 462 // Returns the number of result patterns generated by applying this rewrite 463 // rule. 464 int getNumResultPatterns() const; 465 466 // Returns the DAG tree root node of the `index`-th result pattern. 467 DagNode getResultPattern(unsigned index) const; 468 469 // Collects all symbols bound in the source pattern into `infoMap`. 470 void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap); 471 472 // Collects all symbols bound in result patterns into `infoMap`. 473 void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap); 474 475 // Returns the op that the root node of the source pattern matches. 476 const Operator &getSourceRootOp(); 477 478 // Returns the operator wrapper object corresponding to the given `node`'s DAG 479 // operator. 480 Operator &getDialectOp(DagNode node); 481 482 // Returns the constraints. 483 std::vector<AppliedConstraint> getConstraints() const; 484 485 // Returns the benefit score of the pattern. 486 int getBenefit() const; 487 488 using IdentifierLine = std::pair<StringRef, unsigned>; 489 490 // Returns the file location of the pattern (buffer identifier + line number 491 // pair). 492 std::vector<IdentifierLine> getLocation() const; 493 494 // Recursively collects all bound symbols inside the DAG tree rooted 495 // at `tree` and updates the given `infoMap`. 496 void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 497 bool isSrcPattern); 498 499 private: 500 // Helper function to verify variable binding. 501 void verifyBind(bool result, StringRef symbolName); 502 503 // The TableGen definition of this pattern. 504 const llvm::Record &def; 505 506 // All operators. 507 // TODO: we need a proper context manager, like MLIRContext, for managing the 508 // lifetime of shared entities. 509 RecordOperatorMap *recordOpMap; 510 }; 511 512 } // namespace tblgen 513 } // namespace mlir 514 515 namespace llvm { 516 template <> 517 struct DenseMapInfo<mlir::tblgen::DagNode> { 518 static mlir::tblgen::DagNode getEmptyKey() { 519 return mlir::tblgen::DagNode( 520 llvm::DenseMapInfo<llvm::DagInit *>::getEmptyKey()); 521 } 522 static mlir::tblgen::DagNode getTombstoneKey() { 523 return mlir::tblgen::DagNode( 524 llvm::DenseMapInfo<llvm::DagInit *>::getTombstoneKey()); 525 } 526 static unsigned getHashValue(mlir::tblgen::DagNode node) { 527 return llvm::hash_value(node.getAsOpaquePointer()); 528 } 529 static bool isEqual(mlir::tblgen::DagNode lhs, mlir::tblgen::DagNode rhs) { 530 return lhs.node == rhs.node; 531 } 532 }; 533 534 template <> 535 struct DenseMapInfo<mlir::tblgen::DagLeaf> { 536 static mlir::tblgen::DagLeaf getEmptyKey() { 537 return mlir::tblgen::DagLeaf( 538 llvm::DenseMapInfo<llvm::Init *>::getEmptyKey()); 539 } 540 static mlir::tblgen::DagLeaf getTombstoneKey() { 541 return mlir::tblgen::DagLeaf( 542 llvm::DenseMapInfo<llvm::Init *>::getTombstoneKey()); 543 } 544 static unsigned getHashValue(mlir::tblgen::DagLeaf leaf) { 545 return llvm::hash_value(leaf.getAsOpaquePointer()); 546 } 547 static bool isEqual(mlir::tblgen::DagLeaf lhs, mlir::tblgen::DagLeaf rhs) { 548 return lhs.def == rhs.def; 549 } 550 }; 551 } // namespace llvm 552 553 #endif // MLIR_TABLEGEN_PATTERN_H_ 554