1 //===- Predicate.cpp - Predicate class ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Wrapper around predicates defined in TableGen. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/TableGen/Predicate.h" 14 #include "llvm/ADT/SetVector.h" 15 #include "llvm/ADT/SmallPtrSet.h" 16 #include "llvm/ADT/StringExtras.h" 17 #include "llvm/Support/FormatVariadic.h" 18 #include "llvm/TableGen/Error.h" 19 #include "llvm/TableGen/Record.h" 20 21 using namespace mlir; 22 using namespace tblgen; 23 24 // Construct a Predicate from a record. 25 Pred::Pred(const llvm::Record *record) : def(record) { 26 assert(def->isSubClassOf("Pred") && 27 "must be a subclass of TableGen 'Pred' class"); 28 } 29 30 // Construct a Predicate from an initializer. 31 Pred::Pred(const llvm::Init *init) : def(nullptr) { 32 if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init)) 33 def = defInit->getDef(); 34 } 35 36 std::string Pred::getCondition() const { 37 // Static dispatch to subclasses. 38 if (def->isSubClassOf("CombinedPred")) 39 return static_cast<const CombinedPred *>(this)->getConditionImpl(); 40 if (def->isSubClassOf("CPred")) 41 return static_cast<const CPred *>(this)->getConditionImpl(); 42 llvm_unreachable("Pred::getCondition must be overridden in subclasses"); 43 } 44 45 bool Pred::isCombined() const { 46 return def && def->isSubClassOf("CombinedPred"); 47 } 48 49 ArrayRef<llvm::SMLoc> Pred::getLoc() const { return def->getLoc(); } 50 51 CPred::CPred(const llvm::Record *record) : Pred(record) { 52 assert(def->isSubClassOf("CPred") && 53 "must be a subclass of Tablegen 'CPred' class"); 54 } 55 56 CPred::CPred(const llvm::Init *init) : Pred(init) { 57 assert((!def || def->isSubClassOf("CPred")) && 58 "must be a subclass of Tablegen 'CPred' class"); 59 } 60 61 // Get condition of the C Predicate. 62 std::string CPred::getConditionImpl() const { 63 assert(!isNull() && "null predicate does not have a condition"); 64 return std::string(def->getValueAsString("predExpr")); 65 } 66 67 CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) { 68 assert(def->isSubClassOf("CombinedPred") && 69 "must be a subclass of Tablegen 'CombinedPred' class"); 70 } 71 72 CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) { 73 assert((!def || def->isSubClassOf("CombinedPred")) && 74 "must be a subclass of Tablegen 'CombinedPred' class"); 75 } 76 77 const llvm::Record *CombinedPred::getCombinerDef() const { 78 assert(def->getValue("kind") && "CombinedPred must have a value 'kind'"); 79 return def->getValueAsDef("kind"); 80 } 81 82 std::vector<llvm::Record *> CombinedPred::getChildren() const { 83 assert(def->getValue("children") && 84 "CombinedPred must have a value 'children'"); 85 return def->getValueAsListOfDefs("children"); 86 } 87 88 namespace { 89 // Kinds of nodes in a logical predicate tree. 90 enum class PredCombinerKind { 91 Leaf, 92 And, 93 Or, 94 Not, 95 SubstLeaves, 96 Concat, 97 // Special kinds that are used in simplification. 98 False, 99 True 100 }; 101 102 // A node in a logical predicate tree. 103 struct PredNode { 104 PredCombinerKind kind; 105 const Pred *predicate; 106 SmallVector<PredNode *, 4> children; 107 std::string expr; 108 109 // Prefix and suffix are used by ConcatPred. 110 std::string prefix; 111 std::string suffix; 112 }; 113 } // namespace 114 115 // Get a predicate tree node kind based on the kind used in the predicate 116 // TableGen record. 117 static PredCombinerKind getPredCombinerKind(const Pred &pred) { 118 if (!pred.isCombined()) 119 return PredCombinerKind::Leaf; 120 121 const auto &combinedPred = static_cast<const CombinedPred &>(pred); 122 return StringSwitch<PredCombinerKind>( 123 combinedPred.getCombinerDef()->getName()) 124 .Case("PredCombinerAnd", PredCombinerKind::And) 125 .Case("PredCombinerOr", PredCombinerKind::Or) 126 .Case("PredCombinerNot", PredCombinerKind::Not) 127 .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves) 128 .Case("PredCombinerConcat", PredCombinerKind::Concat); 129 } 130 131 namespace { 132 // Substitution<pattern, replacement>. 133 using Subst = std::pair<StringRef, StringRef>; 134 } // namespace 135 136 /// Perform the given substitutions on 'str' in-place. 137 static void performSubstitutions(std::string &str, 138 ArrayRef<Subst> substitutions) { 139 // Apply all parent substitutions from innermost to outermost. 140 for (const auto &subst : llvm::reverse(substitutions)) { 141 auto pos = str.find(std::string(subst.first)); 142 while (pos != std::string::npos) { 143 str.replace(pos, subst.first.size(), std::string(subst.second)); 144 // Skip the newly inserted substring, which itself may consider the 145 // pattern to match. 146 pos += subst.second.size(); 147 // Find the next possible match position. 148 pos = str.find(std::string(subst.first), pos); 149 } 150 } 151 } 152 153 // Build the predicate tree starting from the top-level predicate, which may 154 // have children, and perform leaf substitutions inplace. Note that after 155 // substitution, nodes are still pointing to the original TableGen record. 156 // All nodes are created within "allocator". 157 static PredNode * 158 buildPredicateTree(const Pred &root, 159 llvm::SpecificBumpPtrAllocator<PredNode> &allocator, 160 ArrayRef<Subst> substitutions) { 161 auto *rootNode = allocator.Allocate(); 162 new (rootNode) PredNode; 163 rootNode->kind = getPredCombinerKind(root); 164 rootNode->predicate = &root; 165 if (!root.isCombined()) { 166 rootNode->expr = root.getCondition(); 167 performSubstitutions(rootNode->expr, substitutions); 168 return rootNode; 169 } 170 171 // If the current combined predicate is a leaf substitution, append it to the 172 // list before continuing. 173 auto allSubstitutions = llvm::to_vector<4>(substitutions); 174 if (rootNode->kind == PredCombinerKind::SubstLeaves) { 175 const auto &substPred = static_cast<const SubstLeavesPred &>(root); 176 allSubstitutions.push_back( 177 {substPred.getPattern(), substPred.getReplacement()}); 178 179 // If the current predicate is a ConcatPred, record the prefix and suffix. 180 } else if (rootNode->kind == PredCombinerKind::Concat) { 181 const auto &concatPred = static_cast<const ConcatPred &>(root); 182 rootNode->prefix = std::string(concatPred.getPrefix()); 183 performSubstitutions(rootNode->prefix, substitutions); 184 rootNode->suffix = std::string(concatPred.getSuffix()); 185 performSubstitutions(rootNode->suffix, substitutions); 186 } 187 188 // Build child subtrees. 189 auto combined = static_cast<const CombinedPred &>(root); 190 for (const auto *record : combined.getChildren()) { 191 auto *childTree = 192 buildPredicateTree(Pred(record), allocator, allSubstitutions); 193 rootNode->children.push_back(childTree); 194 } 195 return rootNode; 196 } 197 198 // Simplify a predicate tree rooted at "node" using the predicates that are 199 // known to be true(false). For AND(OR) combined predicates, if any of the 200 // children is known to be false(true), the result is also false(true). 201 // Furthermore, for AND(OR) combined predicates, children that are known to be 202 // true(false) don't have to be checked dynamically. 203 static PredNode * 204 propagateGroundTruth(PredNode *node, 205 const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds, 206 const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) { 207 // If the current predicate is known to be true or false, change the kind of 208 // the node and return immediately. 209 if (knownTruePreds.count(node->predicate) != 0) { 210 node->kind = PredCombinerKind::True; 211 node->children.clear(); 212 return node; 213 } 214 if (knownFalsePreds.count(node->predicate) != 0) { 215 node->kind = PredCombinerKind::False; 216 node->children.clear(); 217 return node; 218 } 219 220 // If the current node is a substitution, stop recursion now. 221 // The expressions in the leaves below this node were rewritten, but the nodes 222 // still point to the original predicate records. While the original 223 // predicate may be known to be true or false, it is not necessarily the case 224 // after rewriting. 225 // TODO: we can support ground truth for rewritten 226 // predicates by either (a) having our own unique'ing of the predicates 227 // instead of relying on TableGen record pointers or (b) taking ground truth 228 // values optionally prefixed with a list of substitutions to apply, e.g. 229 // "predX is true by itself as well as predSubY leaf substitution had been 230 // applied to it". 231 if (node->kind == PredCombinerKind::SubstLeaves) { 232 return node; 233 } 234 235 // Otherwise, look at child nodes. 236 237 // Move child nodes into some local variable so that they can be optimized 238 // separately and re-added if necessary. 239 llvm::SmallVector<PredNode *, 4> children; 240 std::swap(node->children, children); 241 242 for (auto &child : children) { 243 // First, simplify the child. This maintains the predicate as it was. 244 auto *simplifiedChild = 245 propagateGroundTruth(child, knownTruePreds, knownFalsePreds); 246 247 // Just add the child if we don't know how to simplify the current node. 248 if (node->kind != PredCombinerKind::And && 249 node->kind != PredCombinerKind::Or) { 250 node->children.push_back(simplifiedChild); 251 continue; 252 } 253 254 // Second, based on the type define which known values of child predicates 255 // immediately collapse this predicate to a known value, and which others 256 // may be safely ignored. 257 // OR(..., True, ...) = True 258 // OR(..., False, ...) = OR(..., ...) 259 // AND(..., False, ...) = False 260 // AND(..., True, ...) = AND(..., ...) 261 auto collapseKind = node->kind == PredCombinerKind::And 262 ? PredCombinerKind::False 263 : PredCombinerKind::True; 264 auto eraseKind = node->kind == PredCombinerKind::And 265 ? PredCombinerKind::True 266 : PredCombinerKind::False; 267 const auto &collapseList = 268 node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds; 269 const auto &eraseList = 270 node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds; 271 if (simplifiedChild->kind == collapseKind || 272 collapseList.count(simplifiedChild->predicate) != 0) { 273 node->kind = collapseKind; 274 node->children.clear(); 275 return node; 276 } 277 if (simplifiedChild->kind == eraseKind || 278 eraseList.count(simplifiedChild->predicate) != 0) { 279 continue; 280 } 281 node->children.push_back(simplifiedChild); 282 } 283 return node; 284 } 285 286 // Combine a list of predicate expressions using a binary combiner. If a list 287 // is empty, return "init". 288 static std::string combineBinary(ArrayRef<std::string> children, 289 const std::string &combiner, 290 std::string init) { 291 if (children.empty()) 292 return init; 293 294 auto size = children.size(); 295 if (size == 1) 296 return children.front(); 297 298 std::string str; 299 llvm::raw_string_ostream os(str); 300 os << '(' << children.front() << ')'; 301 for (unsigned i = 1; i < size; ++i) { 302 os << ' ' << combiner << " (" << children[i] << ')'; 303 } 304 return os.str(); 305 } 306 307 // Prepend negation to the only condition in the predicate expression list. 308 static std::string combineNot(ArrayRef<std::string> children) { 309 assert(children.size() == 1 && "expected exactly one child predicate of Neg"); 310 return (Twine("!(") + children.front() + Twine(')')).str(); 311 } 312 313 // Recursively traverse the predicate tree in depth-first post-order and build 314 // the final expression. 315 static std::string getCombinedCondition(const PredNode &root) { 316 // Immediately return for non-combiner predicates that don't have children. 317 if (root.kind == PredCombinerKind::Leaf) 318 return root.expr; 319 if (root.kind == PredCombinerKind::True) 320 return "true"; 321 if (root.kind == PredCombinerKind::False) 322 return "false"; 323 324 // Recurse into children. 325 llvm::SmallVector<std::string, 4> childExpressions; 326 childExpressions.reserve(root.children.size()); 327 for (const auto &child : root.children) 328 childExpressions.push_back(getCombinedCondition(*child)); 329 330 // Combine the expressions based on the predicate node kind. 331 if (root.kind == PredCombinerKind::And) 332 return combineBinary(childExpressions, "&&", "true"); 333 if (root.kind == PredCombinerKind::Or) 334 return combineBinary(childExpressions, "||", "false"); 335 if (root.kind == PredCombinerKind::Not) 336 return combineNot(childExpressions); 337 if (root.kind == PredCombinerKind::Concat) { 338 assert(childExpressions.size() == 1 && 339 "ConcatPred should only have one child"); 340 return root.prefix + childExpressions.front() + root.suffix; 341 } 342 343 // Substitutions were applied before so just ignore them. 344 if (root.kind == PredCombinerKind::SubstLeaves) { 345 assert(childExpressions.size() == 1 && 346 "substitution predicate must have one child"); 347 return childExpressions[0]; 348 } 349 350 llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind"); 351 } 352 353 std::string CombinedPred::getConditionImpl() const { 354 llvm::SpecificBumpPtrAllocator<PredNode> allocator; 355 auto *predicateTree = buildPredicateTree(*this, allocator, {}); 356 predicateTree = 357 propagateGroundTruth(predicateTree, 358 /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(), 359 /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>()); 360 361 return getCombinedCondition(*predicateTree); 362 } 363 364 StringRef SubstLeavesPred::getPattern() const { 365 return def->getValueAsString("pattern"); 366 } 367 368 StringRef SubstLeavesPred::getReplacement() const { 369 return def->getValueAsString("replacement"); 370 } 371 372 StringRef ConcatPred::getPrefix() const { 373 return def->getValueAsString("prefix"); 374 } 375 376 StringRef ConcatPred::getSuffix() const { 377 return def->getValueAsString("suffix"); 378 } 379