1 //===- Predicate.cpp - Predicate class ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Wrapper around predicates defined in TableGen. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/TableGen/Predicate.h" 14 #include "llvm/ADT/SetVector.h" 15 #include "llvm/ADT/SmallPtrSet.h" 16 #include "llvm/ADT/StringExtras.h" 17 #include "llvm/ADT/StringSwitch.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/TableGen/Error.h" 20 #include "llvm/TableGen/Record.h" 21 22 using namespace mlir; 23 using namespace tblgen; 24 25 // Construct a Predicate from a record. 26 Pred::Pred(const llvm::Record *record) : def(record) { 27 assert(def->isSubClassOf("Pred") && 28 "must be a subclass of TableGen 'Pred' class"); 29 } 30 31 // Construct a Predicate from an initializer. 32 Pred::Pred(const llvm::Init *init) { 33 if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init)) 34 def = defInit->getDef(); 35 } 36 37 std::string Pred::getCondition() const { 38 // Static dispatch to subclasses. 39 if (def->isSubClassOf("CombinedPred")) 40 return static_cast<const CombinedPred *>(this)->getConditionImpl(); 41 if (def->isSubClassOf("CPred")) 42 return static_cast<const CPred *>(this)->getConditionImpl(); 43 llvm_unreachable("Pred::getCondition must be overridden in subclasses"); 44 } 45 46 bool Pred::isCombined() const { 47 return def && def->isSubClassOf("CombinedPred"); 48 } 49 50 ArrayRef<SMLoc> Pred::getLoc() const { return def->getLoc(); } 51 52 CPred::CPred(const llvm::Record *record) : Pred(record) { 53 assert(def->isSubClassOf("CPred") && 54 "must be a subclass of Tablegen 'CPred' class"); 55 } 56 57 CPred::CPred(const llvm::Init *init) : Pred(init) { 58 assert((!def || def->isSubClassOf("CPred")) && 59 "must be a subclass of Tablegen 'CPred' class"); 60 } 61 62 // Get condition of the C Predicate. 63 std::string CPred::getConditionImpl() const { 64 assert(!isNull() && "null predicate does not have a condition"); 65 return std::string(def->getValueAsString("predExpr")); 66 } 67 68 CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) { 69 assert(def->isSubClassOf("CombinedPred") && 70 "must be a subclass of Tablegen 'CombinedPred' class"); 71 } 72 73 CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) { 74 assert((!def || def->isSubClassOf("CombinedPred")) && 75 "must be a subclass of Tablegen 'CombinedPred' class"); 76 } 77 78 const llvm::Record *CombinedPred::getCombinerDef() const { 79 assert(def->getValue("kind") && "CombinedPred must have a value 'kind'"); 80 return def->getValueAsDef("kind"); 81 } 82 83 std::vector<llvm::Record *> CombinedPred::getChildren() const { 84 assert(def->getValue("children") && 85 "CombinedPred must have a value 'children'"); 86 return def->getValueAsListOfDefs("children"); 87 } 88 89 namespace { 90 // Kinds of nodes in a logical predicate tree. 91 enum class PredCombinerKind { 92 Leaf, 93 And, 94 Or, 95 Not, 96 SubstLeaves, 97 Concat, 98 // Special kinds that are used in simplification. 99 False, 100 True 101 }; 102 103 // A node in a logical predicate tree. 104 struct PredNode { 105 PredCombinerKind kind; 106 const Pred *predicate; 107 SmallVector<PredNode *, 4> children; 108 std::string expr; 109 110 // Prefix and suffix are used by ConcatPred. 111 std::string prefix; 112 std::string suffix; 113 }; 114 } // namespace 115 116 // Get a predicate tree node kind based on the kind used in the predicate 117 // TableGen record. 118 static PredCombinerKind getPredCombinerKind(const Pred &pred) { 119 if (!pred.isCombined()) 120 return PredCombinerKind::Leaf; 121 122 const auto &combinedPred = static_cast<const CombinedPred &>(pred); 123 return StringSwitch<PredCombinerKind>( 124 combinedPred.getCombinerDef()->getName()) 125 .Case("PredCombinerAnd", PredCombinerKind::And) 126 .Case("PredCombinerOr", PredCombinerKind::Or) 127 .Case("PredCombinerNot", PredCombinerKind::Not) 128 .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves) 129 .Case("PredCombinerConcat", PredCombinerKind::Concat); 130 } 131 132 namespace { 133 // Substitution<pattern, replacement>. 134 using Subst = std::pair<StringRef, StringRef>; 135 } // namespace 136 137 /// Perform the given substitutions on 'str' in-place. 138 static void performSubstitutions(std::string &str, 139 ArrayRef<Subst> substitutions) { 140 // Apply all parent substitutions from innermost to outermost. 141 for (const auto &subst : llvm::reverse(substitutions)) { 142 auto pos = str.find(std::string(subst.first)); 143 while (pos != std::string::npos) { 144 str.replace(pos, subst.first.size(), std::string(subst.second)); 145 // Skip the newly inserted substring, which itself may consider the 146 // pattern to match. 147 pos += subst.second.size(); 148 // Find the next possible match position. 149 pos = str.find(std::string(subst.first), pos); 150 } 151 } 152 } 153 154 // Build the predicate tree starting from the top-level predicate, which may 155 // have children, and perform leaf substitutions inplace. Note that after 156 // substitution, nodes are still pointing to the original TableGen record. 157 // All nodes are created within "allocator". 158 static PredNode * 159 buildPredicateTree(const Pred &root, 160 llvm::SpecificBumpPtrAllocator<PredNode> &allocator, 161 ArrayRef<Subst> substitutions) { 162 auto *rootNode = allocator.Allocate(); 163 new (rootNode) PredNode; 164 rootNode->kind = getPredCombinerKind(root); 165 rootNode->predicate = &root; 166 if (!root.isCombined()) { 167 rootNode->expr = root.getCondition(); 168 performSubstitutions(rootNode->expr, substitutions); 169 return rootNode; 170 } 171 172 // If the current combined predicate is a leaf substitution, append it to the 173 // list before continuing. 174 auto allSubstitutions = llvm::to_vector<4>(substitutions); 175 if (rootNode->kind == PredCombinerKind::SubstLeaves) { 176 const auto &substPred = static_cast<const SubstLeavesPred &>(root); 177 allSubstitutions.push_back( 178 {substPred.getPattern(), substPred.getReplacement()}); 179 180 // If the current predicate is a ConcatPred, record the prefix and suffix. 181 } else if (rootNode->kind == PredCombinerKind::Concat) { 182 const auto &concatPred = static_cast<const ConcatPred &>(root); 183 rootNode->prefix = std::string(concatPred.getPrefix()); 184 performSubstitutions(rootNode->prefix, substitutions); 185 rootNode->suffix = std::string(concatPred.getSuffix()); 186 performSubstitutions(rootNode->suffix, substitutions); 187 } 188 189 // Build child subtrees. 190 auto combined = static_cast<const CombinedPred &>(root); 191 for (const auto *record : combined.getChildren()) { 192 auto *childTree = 193 buildPredicateTree(Pred(record), allocator, allSubstitutions); 194 rootNode->children.push_back(childTree); 195 } 196 return rootNode; 197 } 198 199 // Simplify a predicate tree rooted at "node" using the predicates that are 200 // known to be true(false). For AND(OR) combined predicates, if any of the 201 // children is known to be false(true), the result is also false(true). 202 // Furthermore, for AND(OR) combined predicates, children that are known to be 203 // true(false) don't have to be checked dynamically. 204 static PredNode * 205 propagateGroundTruth(PredNode *node, 206 const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds, 207 const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) { 208 // If the current predicate is known to be true or false, change the kind of 209 // the node and return immediately. 210 if (knownTruePreds.count(node->predicate) != 0) { 211 node->kind = PredCombinerKind::True; 212 node->children.clear(); 213 return node; 214 } 215 if (knownFalsePreds.count(node->predicate) != 0) { 216 node->kind = PredCombinerKind::False; 217 node->children.clear(); 218 return node; 219 } 220 221 // If the current node is a substitution, stop recursion now. 222 // The expressions in the leaves below this node were rewritten, but the nodes 223 // still point to the original predicate records. While the original 224 // predicate may be known to be true or false, it is not necessarily the case 225 // after rewriting. 226 // TODO: we can support ground truth for rewritten 227 // predicates by either (a) having our own unique'ing of the predicates 228 // instead of relying on TableGen record pointers or (b) taking ground truth 229 // values optionally prefixed with a list of substitutions to apply, e.g. 230 // "predX is true by itself as well as predSubY leaf substitution had been 231 // applied to it". 232 if (node->kind == PredCombinerKind::SubstLeaves) { 233 return node; 234 } 235 236 // Otherwise, look at child nodes. 237 238 // Move child nodes into some local variable so that they can be optimized 239 // separately and re-added if necessary. 240 llvm::SmallVector<PredNode *, 4> children; 241 std::swap(node->children, children); 242 243 for (auto &child : children) { 244 // First, simplify the child. This maintains the predicate as it was. 245 auto *simplifiedChild = 246 propagateGroundTruth(child, knownTruePreds, knownFalsePreds); 247 248 // Just add the child if we don't know how to simplify the current node. 249 if (node->kind != PredCombinerKind::And && 250 node->kind != PredCombinerKind::Or) { 251 node->children.push_back(simplifiedChild); 252 continue; 253 } 254 255 // Second, based on the type define which known values of child predicates 256 // immediately collapse this predicate to a known value, and which others 257 // may be safely ignored. 258 // OR(..., True, ...) = True 259 // OR(..., False, ...) = OR(..., ...) 260 // AND(..., False, ...) = False 261 // AND(..., True, ...) = AND(..., ...) 262 auto collapseKind = node->kind == PredCombinerKind::And 263 ? PredCombinerKind::False 264 : PredCombinerKind::True; 265 auto eraseKind = node->kind == PredCombinerKind::And 266 ? PredCombinerKind::True 267 : PredCombinerKind::False; 268 const auto &collapseList = 269 node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds; 270 const auto &eraseList = 271 node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds; 272 if (simplifiedChild->kind == collapseKind || 273 collapseList.count(simplifiedChild->predicate) != 0) { 274 node->kind = collapseKind; 275 node->children.clear(); 276 return node; 277 } 278 if (simplifiedChild->kind == eraseKind || 279 eraseList.count(simplifiedChild->predicate) != 0) { 280 continue; 281 } 282 node->children.push_back(simplifiedChild); 283 } 284 return node; 285 } 286 287 // Combine a list of predicate expressions using a binary combiner. If a list 288 // is empty, return "init". 289 static std::string combineBinary(ArrayRef<std::string> children, 290 const std::string &combiner, 291 std::string init) { 292 if (children.empty()) 293 return init; 294 295 auto size = children.size(); 296 if (size == 1) 297 return children.front(); 298 299 std::string str; 300 llvm::raw_string_ostream os(str); 301 os << '(' << children.front() << ')'; 302 for (unsigned i = 1; i < size; ++i) { 303 os << ' ' << combiner << " (" << children[i] << ')'; 304 } 305 return os.str(); 306 } 307 308 // Prepend negation to the only condition in the predicate expression list. 309 static std::string combineNot(ArrayRef<std::string> children) { 310 assert(children.size() == 1 && "expected exactly one child predicate of Neg"); 311 return (Twine("!(") + children.front() + Twine(')')).str(); 312 } 313 314 // Recursively traverse the predicate tree in depth-first post-order and build 315 // the final expression. 316 static std::string getCombinedCondition(const PredNode &root) { 317 // Immediately return for non-combiner predicates that don't have children. 318 if (root.kind == PredCombinerKind::Leaf) 319 return root.expr; 320 if (root.kind == PredCombinerKind::True) 321 return "true"; 322 if (root.kind == PredCombinerKind::False) 323 return "false"; 324 325 // Recurse into children. 326 llvm::SmallVector<std::string, 4> childExpressions; 327 childExpressions.reserve(root.children.size()); 328 for (const auto &child : root.children) 329 childExpressions.push_back(getCombinedCondition(*child)); 330 331 // Combine the expressions based on the predicate node kind. 332 if (root.kind == PredCombinerKind::And) 333 return combineBinary(childExpressions, "&&", "true"); 334 if (root.kind == PredCombinerKind::Or) 335 return combineBinary(childExpressions, "||", "false"); 336 if (root.kind == PredCombinerKind::Not) 337 return combineNot(childExpressions); 338 if (root.kind == PredCombinerKind::Concat) { 339 assert(childExpressions.size() == 1 && 340 "ConcatPred should only have one child"); 341 return root.prefix + childExpressions.front() + root.suffix; 342 } 343 344 // Substitutions were applied before so just ignore them. 345 if (root.kind == PredCombinerKind::SubstLeaves) { 346 assert(childExpressions.size() == 1 && 347 "substitution predicate must have one child"); 348 return childExpressions[0]; 349 } 350 351 llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind"); 352 } 353 354 std::string CombinedPred::getConditionImpl() const { 355 llvm::SpecificBumpPtrAllocator<PredNode> allocator; 356 auto *predicateTree = buildPredicateTree(*this, allocator, {}); 357 predicateTree = 358 propagateGroundTruth(predicateTree, 359 /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(), 360 /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>()); 361 362 return getCombinedCondition(*predicateTree); 363 } 364 365 StringRef SubstLeavesPred::getPattern() const { 366 return def->getValueAsString("pattern"); 367 } 368 369 StringRef SubstLeavesPred::getReplacement() const { 370 return def->getValueAsString("replacement"); 371 } 372 373 StringRef ConcatPred::getPrefix() const { 374 return def->getValueAsString("prefix"); 375 } 376 377 StringRef ConcatPred::getSuffix() const { 378 return def->getValueAsString("suffix"); 379 } 380