1 //===- Predicate.cpp - Predicate class ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Wrapper around predicates defined in TableGen. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/TableGen/Predicate.h" 14 #include "llvm/ADT/SetVector.h" 15 #include "llvm/ADT/SmallPtrSet.h" 16 #include "llvm/ADT/StringExtras.h" 17 #include "llvm/Support/FormatVariadic.h" 18 #include "llvm/TableGen/Error.h" 19 #include "llvm/TableGen/Record.h" 20 21 using namespace mlir; 22 23 // Construct a Predicate from a record. 24 tblgen::Pred::Pred(const llvm::Record *record) : def(record) { 25 assert(def->isSubClassOf("Pred") && 26 "must be a subclass of TableGen 'Pred' class"); 27 } 28 29 // Construct a Predicate from an initializer. 30 tblgen::Pred::Pred(const llvm::Init *init) : def(nullptr) { 31 if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init)) 32 def = defInit->getDef(); 33 } 34 35 std::string tblgen::Pred::getCondition() const { 36 // Static dispatch to subclasses. 37 if (def->isSubClassOf("CombinedPred")) 38 return static_cast<const CombinedPred *>(this)->getConditionImpl(); 39 if (def->isSubClassOf("CPred")) 40 return static_cast<const CPred *>(this)->getConditionImpl(); 41 llvm_unreachable("Pred::getCondition must be overridden in subclasses"); 42 } 43 44 bool tblgen::Pred::isCombined() const { 45 return def && def->isSubClassOf("CombinedPred"); 46 } 47 48 ArrayRef<llvm::SMLoc> tblgen::Pred::getLoc() const { return def->getLoc(); } 49 50 tblgen::CPred::CPred(const llvm::Record *record) : Pred(record) { 51 assert(def->isSubClassOf("CPred") && 52 "must be a subclass of Tablegen 'CPred' class"); 53 } 54 55 tblgen::CPred::CPred(const llvm::Init *init) : Pred(init) { 56 assert((!def || def->isSubClassOf("CPred")) && 57 "must be a subclass of Tablegen 'CPred' class"); 58 } 59 60 // Get condition of the C Predicate. 61 std::string tblgen::CPred::getConditionImpl() const { 62 assert(!isNull() && "null predicate does not have a condition"); 63 return std::string(def->getValueAsString("predExpr")); 64 } 65 66 tblgen::CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) { 67 assert(def->isSubClassOf("CombinedPred") && 68 "must be a subclass of Tablegen 'CombinedPred' class"); 69 } 70 71 tblgen::CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) { 72 assert((!def || def->isSubClassOf("CombinedPred")) && 73 "must be a subclass of Tablegen 'CombinedPred' class"); 74 } 75 76 const llvm::Record *tblgen::CombinedPred::getCombinerDef() const { 77 assert(def->getValue("kind") && "CombinedPred must have a value 'kind'"); 78 return def->getValueAsDef("kind"); 79 } 80 81 const std::vector<llvm::Record *> tblgen::CombinedPred::getChildren() const { 82 assert(def->getValue("children") && 83 "CombinedPred must have a value 'children'"); 84 return def->getValueAsListOfDefs("children"); 85 } 86 87 namespace { 88 // Kinds of nodes in a logical predicate tree. 89 enum class PredCombinerKind { 90 Leaf, 91 And, 92 Or, 93 Not, 94 SubstLeaves, 95 Concat, 96 // Special kinds that are used in simplification. 97 False, 98 True 99 }; 100 101 // A node in a logical predicate tree. 102 struct PredNode { 103 PredCombinerKind kind; 104 const tblgen::Pred *predicate; 105 SmallVector<PredNode *, 4> children; 106 std::string expr; 107 108 // Prefix and suffix are used by ConcatPred. 109 std::string prefix; 110 std::string suffix; 111 }; 112 } // end anonymous namespace 113 114 // Get a predicate tree node kind based on the kind used in the predicate 115 // TableGen record. 116 static PredCombinerKind getPredCombinerKind(const tblgen::Pred &pred) { 117 if (!pred.isCombined()) 118 return PredCombinerKind::Leaf; 119 120 const auto &combinedPred = static_cast<const tblgen::CombinedPred &>(pred); 121 return llvm::StringSwitch<PredCombinerKind>( 122 combinedPred.getCombinerDef()->getName()) 123 .Case("PredCombinerAnd", PredCombinerKind::And) 124 .Case("PredCombinerOr", PredCombinerKind::Or) 125 .Case("PredCombinerNot", PredCombinerKind::Not) 126 .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves) 127 .Case("PredCombinerConcat", PredCombinerKind::Concat); 128 } 129 130 namespace { 131 // Substitution<pattern, replacement>. 132 using Subst = std::pair<StringRef, StringRef>; 133 } // end anonymous namespace 134 135 // Build the predicate tree starting from the top-level predicate, which may 136 // have children, and perform leaf substitutions inplace. Note that after 137 // substitution, nodes are still pointing to the original TableGen record. 138 // All nodes are created within "allocator". 139 static PredNode * 140 buildPredicateTree(const tblgen::Pred &root, 141 llvm::SpecificBumpPtrAllocator<PredNode> &allocator, 142 ArrayRef<Subst> substitutions) { 143 auto *rootNode = allocator.Allocate(); 144 new (rootNode) PredNode; 145 rootNode->kind = getPredCombinerKind(root); 146 rootNode->predicate = &root; 147 if (!root.isCombined()) { 148 rootNode->expr = root.getCondition(); 149 // Apply all parent substitutions from innermost to outermost. 150 for (const auto &subst : llvm::reverse(substitutions)) { 151 auto pos = rootNode->expr.find(std::string(subst.first)); 152 while (pos != std::string::npos) { 153 rootNode->expr.replace(pos, subst.first.size(), 154 std::string(subst.second)); 155 // Skip the newly inserted substring, which itself may consider the 156 // pattern to match. 157 pos += subst.second.size(); 158 // Find the next possible match position. 159 pos = rootNode->expr.find(std::string(subst.first), pos); 160 } 161 } 162 return rootNode; 163 } 164 165 // If the current combined predicate is a leaf substitution, append it to the 166 // list before continuing. 167 auto allSubstitutions = llvm::to_vector<4>(substitutions); 168 if (rootNode->kind == PredCombinerKind::SubstLeaves) { 169 const auto &substPred = static_cast<const tblgen::SubstLeavesPred &>(root); 170 allSubstitutions.push_back( 171 {substPred.getPattern(), substPred.getReplacement()}); 172 } 173 // If the current predicate is a ConcatPred, record the prefix and suffix. 174 else if (rootNode->kind == PredCombinerKind::Concat) { 175 const auto &concatPred = static_cast<const tblgen::ConcatPred &>(root); 176 rootNode->prefix = std::string(concatPred.getPrefix()); 177 rootNode->suffix = std::string(concatPred.getSuffix()); 178 } 179 180 // Build child subtrees. 181 auto combined = static_cast<const tblgen::CombinedPred &>(root); 182 for (const auto *record : combined.getChildren()) { 183 auto childTree = 184 buildPredicateTree(tblgen::Pred(record), allocator, allSubstitutions); 185 rootNode->children.push_back(childTree); 186 } 187 return rootNode; 188 } 189 190 // Simplify a predicate tree rooted at "node" using the predicates that are 191 // known to be true(false). For AND(OR) combined predicates, if any of the 192 // children is known to be false(true), the result is also false(true). 193 // Furthermore, for AND(OR) combined predicates, children that are known to be 194 // true(false) don't have to be checked dynamically. 195 static PredNode *propagateGroundTruth( 196 PredNode *node, const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownTruePreds, 197 const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownFalsePreds) { 198 // If the current predicate is known to be true or false, change the kind of 199 // the node and return immediately. 200 if (knownTruePreds.count(node->predicate) != 0) { 201 node->kind = PredCombinerKind::True; 202 node->children.clear(); 203 return node; 204 } 205 if (knownFalsePreds.count(node->predicate) != 0) { 206 node->kind = PredCombinerKind::False; 207 node->children.clear(); 208 return node; 209 } 210 211 // If the current node is a substitution, stop recursion now. 212 // The expressions in the leaves below this node were rewritten, but the nodes 213 // still point to the original predicate records. While the original 214 // predicate may be known to be true or false, it is not necessarily the case 215 // after rewriting. 216 // TODO(zinenko,jpienaar): we can support ground truth for rewritten 217 // predicates by either (a) having our own unique'ing of the predicates 218 // instead of relying on TableGen record pointers or (b) taking ground truth 219 // values optionally prefixed with a list of substitutions to apply, e.g. 220 // "predX is true by itself as well as predSubY leaf substitution had been 221 // applied to it". 222 if (node->kind == PredCombinerKind::SubstLeaves) { 223 return node; 224 } 225 226 // Otherwise, look at child nodes. 227 228 // Move child nodes into some local variable so that they can be optimized 229 // separately and re-added if necessary. 230 llvm::SmallVector<PredNode *, 4> children; 231 std::swap(node->children, children); 232 233 for (auto &child : children) { 234 // First, simplify the child. This maintains the predicate as it was. 235 auto simplifiedChild = 236 propagateGroundTruth(child, knownTruePreds, knownFalsePreds); 237 238 // Just add the child if we don't know how to simplify the current node. 239 if (node->kind != PredCombinerKind::And && 240 node->kind != PredCombinerKind::Or) { 241 node->children.push_back(simplifiedChild); 242 continue; 243 } 244 245 // Second, based on the type define which known values of child predicates 246 // immediately collapse this predicate to a known value, and which others 247 // may be safely ignored. 248 // OR(..., True, ...) = True 249 // OR(..., False, ...) = OR(..., ...) 250 // AND(..., False, ...) = False 251 // AND(..., True, ...) = AND(..., ...) 252 auto collapseKind = node->kind == PredCombinerKind::And 253 ? PredCombinerKind::False 254 : PredCombinerKind::True; 255 auto eraseKind = node->kind == PredCombinerKind::And 256 ? PredCombinerKind::True 257 : PredCombinerKind::False; 258 const auto &collapseList = 259 node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds; 260 const auto &eraseList = 261 node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds; 262 if (simplifiedChild->kind == collapseKind || 263 collapseList.count(simplifiedChild->predicate) != 0) { 264 node->kind = collapseKind; 265 node->children.clear(); 266 return node; 267 } else if (simplifiedChild->kind == eraseKind || 268 eraseList.count(simplifiedChild->predicate) != 0) { 269 continue; 270 } 271 node->children.push_back(simplifiedChild); 272 } 273 return node; 274 } 275 276 // Combine a list of predicate expressions using a binary combiner. If a list 277 // is empty, return "init". 278 static std::string combineBinary(ArrayRef<std::string> children, 279 std::string combiner, std::string init) { 280 if (children.empty()) 281 return init; 282 283 auto size = children.size(); 284 if (size == 1) 285 return children.front(); 286 287 std::string str; 288 llvm::raw_string_ostream os(str); 289 os << '(' << children.front() << ')'; 290 for (unsigned i = 1; i < size; ++i) { 291 os << ' ' << combiner << " (" << children[i] << ')'; 292 } 293 return os.str(); 294 } 295 296 // Prepend negation to the only condition in the predicate expression list. 297 static std::string combineNot(ArrayRef<std::string> children) { 298 assert(children.size() == 1 && "expected exactly one child predicate of Neg"); 299 return (Twine("!(") + children.front() + Twine(')')).str(); 300 } 301 302 // Recursively traverse the predicate tree in depth-first post-order and build 303 // the final expression. 304 static std::string getCombinedCondition(const PredNode &root) { 305 // Immediately return for non-combiner predicates that don't have children. 306 if (root.kind == PredCombinerKind::Leaf) 307 return root.expr; 308 if (root.kind == PredCombinerKind::True) 309 return "true"; 310 if (root.kind == PredCombinerKind::False) 311 return "false"; 312 313 // Recurse into children. 314 llvm::SmallVector<std::string, 4> childExpressions; 315 childExpressions.reserve(root.children.size()); 316 for (const auto &child : root.children) 317 childExpressions.push_back(getCombinedCondition(*child)); 318 319 // Combine the expressions based on the predicate node kind. 320 if (root.kind == PredCombinerKind::And) 321 return combineBinary(childExpressions, "&&", "true"); 322 if (root.kind == PredCombinerKind::Or) 323 return combineBinary(childExpressions, "||", "false"); 324 if (root.kind == PredCombinerKind::Not) 325 return combineNot(childExpressions); 326 if (root.kind == PredCombinerKind::Concat) { 327 assert(childExpressions.size() == 1 && 328 "ConcatPred should only have one child"); 329 return root.prefix + childExpressions.front() + root.suffix; 330 } 331 332 // Substitutions were applied before so just ignore them. 333 if (root.kind == PredCombinerKind::SubstLeaves) { 334 assert(childExpressions.size() == 1 && 335 "substitution predicate must have one child"); 336 return childExpressions[0]; 337 } 338 339 llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind"); 340 } 341 342 std::string tblgen::CombinedPred::getConditionImpl() const { 343 llvm::SpecificBumpPtrAllocator<PredNode> allocator; 344 auto predicateTree = buildPredicateTree(*this, allocator, {}); 345 predicateTree = propagateGroundTruth( 346 predicateTree, 347 /*knownTruePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>(), 348 /*knownFalsePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>()); 349 350 return getCombinedCondition(*predicateTree); 351 } 352 353 StringRef tblgen::SubstLeavesPred::getPattern() const { 354 return def->getValueAsString("pattern"); 355 } 356 357 StringRef tblgen::SubstLeavesPred::getReplacement() const { 358 return def->getValueAsString("replacement"); 359 } 360 361 StringRef tblgen::ConcatPred::getPrefix() const { 362 return def->getValueAsString("prefix"); 363 } 364 365 StringRef tblgen::ConcatPred::getSuffix() const { 366 return def->getValueAsString("suffix"); 367 } 368