1 //===- Predicate.cpp - Predicate class ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Wrapper around predicates defined in TableGen.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/TableGen/Predicate.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/SmallPtrSet.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/StringSwitch.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 using namespace mlir;
23 using namespace tblgen;
24 
25 // Construct a Predicate from a record.
Pred(const llvm::Record * record)26 Pred::Pred(const llvm::Record *record) : def(record) {
27   assert(def->isSubClassOf("Pred") &&
28          "must be a subclass of TableGen 'Pred' class");
29 }
30 
31 // Construct a Predicate from an initializer.
Pred(const llvm::Init * init)32 Pred::Pred(const llvm::Init *init) {
33   if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
34     def = defInit->getDef();
35 }
36 
getCondition() const37 std::string Pred::getCondition() const {
38   // Static dispatch to subclasses.
39   if (def->isSubClassOf("CombinedPred"))
40     return static_cast<const CombinedPred *>(this)->getConditionImpl();
41   if (def->isSubClassOf("CPred"))
42     return static_cast<const CPred *>(this)->getConditionImpl();
43   llvm_unreachable("Pred::getCondition must be overridden in subclasses");
44 }
45 
isCombined() const46 bool Pred::isCombined() const {
47   return def && def->isSubClassOf("CombinedPred");
48 }
49 
getLoc() const50 ArrayRef<SMLoc> Pred::getLoc() const { return def->getLoc(); }
51 
CPred(const llvm::Record * record)52 CPred::CPred(const llvm::Record *record) : Pred(record) {
53   assert(def->isSubClassOf("CPred") &&
54          "must be a subclass of Tablegen 'CPred' class");
55 }
56 
CPred(const llvm::Init * init)57 CPred::CPred(const llvm::Init *init) : Pred(init) {
58   assert((!def || def->isSubClassOf("CPred")) &&
59          "must be a subclass of Tablegen 'CPred' class");
60 }
61 
62 // Get condition of the C Predicate.
getConditionImpl() const63 std::string CPred::getConditionImpl() const {
64   assert(!isNull() && "null predicate does not have a condition");
65   return std::string(def->getValueAsString("predExpr"));
66 }
67 
CombinedPred(const llvm::Record * record)68 CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
69   assert(def->isSubClassOf("CombinedPred") &&
70          "must be a subclass of Tablegen 'CombinedPred' class");
71 }
72 
CombinedPred(const llvm::Init * init)73 CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
74   assert((!def || def->isSubClassOf("CombinedPred")) &&
75          "must be a subclass of Tablegen 'CombinedPred' class");
76 }
77 
getCombinerDef() const78 const llvm::Record *CombinedPred::getCombinerDef() const {
79   assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
80   return def->getValueAsDef("kind");
81 }
82 
getChildren() const83 std::vector<llvm::Record *> CombinedPred::getChildren() const {
84   assert(def->getValue("children") &&
85          "CombinedPred must have a value 'children'");
86   return def->getValueAsListOfDefs("children");
87 }
88 
89 namespace {
90 // Kinds of nodes in a logical predicate tree.
91 enum class PredCombinerKind {
92   Leaf,
93   And,
94   Or,
95   Not,
96   SubstLeaves,
97   Concat,
98   // Special kinds that are used in simplification.
99   False,
100   True
101 };
102 
103 // A node in a logical predicate tree.
104 struct PredNode {
105   PredCombinerKind kind;
106   const Pred *predicate;
107   SmallVector<PredNode *, 4> children;
108   std::string expr;
109 
110   // Prefix and suffix are used by ConcatPred.
111   std::string prefix;
112   std::string suffix;
113 };
114 } // namespace
115 
116 // Get a predicate tree node kind based on the kind used in the predicate
117 // TableGen record.
getPredCombinerKind(const Pred & pred)118 static PredCombinerKind getPredCombinerKind(const Pred &pred) {
119   if (!pred.isCombined())
120     return PredCombinerKind::Leaf;
121 
122   const auto &combinedPred = static_cast<const CombinedPred &>(pred);
123   return StringSwitch<PredCombinerKind>(
124              combinedPred.getCombinerDef()->getName())
125       .Case("PredCombinerAnd", PredCombinerKind::And)
126       .Case("PredCombinerOr", PredCombinerKind::Or)
127       .Case("PredCombinerNot", PredCombinerKind::Not)
128       .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
129       .Case("PredCombinerConcat", PredCombinerKind::Concat);
130 }
131 
namespace {
// A single substitution, as a (pattern, replacement) pair: performSubstitutions
// replaces every occurrence of `first` in an expression string with `second`.
using Subst = std::pair<StringRef, StringRef>;
} // namespace
136 
137 /// Perform the given substitutions on 'str' in-place.
performSubstitutions(std::string & str,ArrayRef<Subst> substitutions)138 static void performSubstitutions(std::string &str,
139                                  ArrayRef<Subst> substitutions) {
140   // Apply all parent substitutions from innermost to outermost.
141   for (const auto &subst : llvm::reverse(substitutions)) {
142     auto pos = str.find(std::string(subst.first));
143     while (pos != std::string::npos) {
144       str.replace(pos, subst.first.size(), std::string(subst.second));
145       // Skip the newly inserted substring, which itself may consider the
146       // pattern to match.
147       pos += subst.second.size();
148       // Find the next possible match position.
149       pos = str.find(std::string(subst.first), pos);
150     }
151   }
152 }
153 
154 // Build the predicate tree starting from the top-level predicate, which may
155 // have children, and perform leaf substitutions inplace.  Note that after
156 // substitution, nodes are still pointing to the original TableGen record.
157 // All nodes are created within "allocator".
158 static PredNode *
buildPredicateTree(const Pred & root,llvm::SpecificBumpPtrAllocator<PredNode> & allocator,ArrayRef<Subst> substitutions)159 buildPredicateTree(const Pred &root,
160                    llvm::SpecificBumpPtrAllocator<PredNode> &allocator,
161                    ArrayRef<Subst> substitutions) {
162   auto *rootNode = allocator.Allocate();
163   new (rootNode) PredNode;
164   rootNode->kind = getPredCombinerKind(root);
165   rootNode->predicate = &root;
166   if (!root.isCombined()) {
167     rootNode->expr = root.getCondition();
168     performSubstitutions(rootNode->expr, substitutions);
169     return rootNode;
170   }
171 
172   // If the current combined predicate is a leaf substitution, append it to the
173   // list before continuing.
174   auto allSubstitutions = llvm::to_vector<4>(substitutions);
175   if (rootNode->kind == PredCombinerKind::SubstLeaves) {
176     const auto &substPred = static_cast<const SubstLeavesPred &>(root);
177     allSubstitutions.push_back(
178         {substPred.getPattern(), substPred.getReplacement()});
179 
180     // If the current predicate is a ConcatPred, record the prefix and suffix.
181   } else if (rootNode->kind == PredCombinerKind::Concat) {
182     const auto &concatPred = static_cast<const ConcatPred &>(root);
183     rootNode->prefix = std::string(concatPred.getPrefix());
184     performSubstitutions(rootNode->prefix, substitutions);
185     rootNode->suffix = std::string(concatPred.getSuffix());
186     performSubstitutions(rootNode->suffix, substitutions);
187   }
188 
189   // Build child subtrees.
190   auto combined = static_cast<const CombinedPred &>(root);
191   for (const auto *record : combined.getChildren()) {
192     auto *childTree =
193         buildPredicateTree(Pred(record), allocator, allSubstitutions);
194     rootNode->children.push_back(childTree);
195   }
196   return rootNode;
197 }
198 
199 // Simplify a predicate tree rooted at "node" using the predicates that are
200 // known to be true(false).  For AND(OR) combined predicates, if any of the
201 // children is known to be false(true), the result is also false(true).
202 // Furthermore, for AND(OR) combined predicates, children that are known to be
203 // true(false) don't have to be checked dynamically.
204 static PredNode *
propagateGroundTruth(PredNode * node,const llvm::SmallPtrSetImpl<Pred * > & knownTruePreds,const llvm::SmallPtrSetImpl<Pred * > & knownFalsePreds)205 propagateGroundTruth(PredNode *node,
206                      const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
207                      const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
208   // If the current predicate is known to be true or false, change the kind of
209   // the node and return immediately.
210   if (knownTruePreds.count(node->predicate) != 0) {
211     node->kind = PredCombinerKind::True;
212     node->children.clear();
213     return node;
214   }
215   if (knownFalsePreds.count(node->predicate) != 0) {
216     node->kind = PredCombinerKind::False;
217     node->children.clear();
218     return node;
219   }
220 
221   // If the current node is a substitution, stop recursion now.
222   // The expressions in the leaves below this node were rewritten, but the nodes
223   // still point to the original predicate records.  While the original
224   // predicate may be known to be true or false, it is not necessarily the case
225   // after rewriting.
226   // TODO: we can support ground truth for rewritten
227   // predicates by either (a) having our own unique'ing of the predicates
228   // instead of relying on TableGen record pointers or (b) taking ground truth
229   // values optionally prefixed with a list of substitutions to apply, e.g.
230   // "predX is true by itself as well as predSubY leaf substitution had been
231   // applied to it".
232   if (node->kind == PredCombinerKind::SubstLeaves) {
233     return node;
234   }
235 
236   // Otherwise, look at child nodes.
237 
238   // Move child nodes into some local variable so that they can be optimized
239   // separately and re-added if necessary.
240   llvm::SmallVector<PredNode *, 4> children;
241   std::swap(node->children, children);
242 
243   for (auto &child : children) {
244     // First, simplify the child.  This maintains the predicate as it was.
245     auto *simplifiedChild =
246         propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
247 
248     // Just add the child if we don't know how to simplify the current node.
249     if (node->kind != PredCombinerKind::And &&
250         node->kind != PredCombinerKind::Or) {
251       node->children.push_back(simplifiedChild);
252       continue;
253     }
254 
255     // Second, based on the type define which known values of child predicates
256     // immediately collapse this predicate to a known value, and which others
257     // may be safely ignored.
258     //   OR(..., True, ...) = True
259     //   OR(..., False, ...) = OR(..., ...)
260     //   AND(..., False, ...) = False
261     //   AND(..., True, ...) = AND(..., ...)
262     auto collapseKind = node->kind == PredCombinerKind::And
263                             ? PredCombinerKind::False
264                             : PredCombinerKind::True;
265     auto eraseKind = node->kind == PredCombinerKind::And
266                          ? PredCombinerKind::True
267                          : PredCombinerKind::False;
268     const auto &collapseList =
269         node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
270     const auto &eraseList =
271         node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
272     if (simplifiedChild->kind == collapseKind ||
273         collapseList.count(simplifiedChild->predicate) != 0) {
274       node->kind = collapseKind;
275       node->children.clear();
276       return node;
277     }
278     if (simplifiedChild->kind == eraseKind ||
279         eraseList.count(simplifiedChild->predicate) != 0) {
280       continue;
281     }
282     node->children.push_back(simplifiedChild);
283   }
284   return node;
285 }
286 
287 // Combine a list of predicate expressions using a binary combiner.  If a list
288 // is empty, return "init".
combineBinary(ArrayRef<std::string> children,const std::string & combiner,std::string init)289 static std::string combineBinary(ArrayRef<std::string> children,
290                                  const std::string &combiner,
291                                  std::string init) {
292   if (children.empty())
293     return init;
294 
295   auto size = children.size();
296   if (size == 1)
297     return children.front();
298 
299   std::string str;
300   llvm::raw_string_ostream os(str);
301   os << '(' << children.front() << ')';
302   for (unsigned i = 1; i < size; ++i) {
303     os << ' ' << combiner << " (" << children[i] << ')';
304   }
305   return os.str();
306 }
307 
308 // Prepend negation to the only condition in the predicate expression list.
combineNot(ArrayRef<std::string> children)309 static std::string combineNot(ArrayRef<std::string> children) {
310   assert(children.size() == 1 && "expected exactly one child predicate of Neg");
311   return (Twine("!(") + children.front() + Twine(')')).str();
312 }
313 
314 // Recursively traverse the predicate tree in depth-first post-order and build
315 // the final expression.
getCombinedCondition(const PredNode & root)316 static std::string getCombinedCondition(const PredNode &root) {
317   // Immediately return for non-combiner predicates that don't have children.
318   if (root.kind == PredCombinerKind::Leaf)
319     return root.expr;
320   if (root.kind == PredCombinerKind::True)
321     return "true";
322   if (root.kind == PredCombinerKind::False)
323     return "false";
324 
325   // Recurse into children.
326   llvm::SmallVector<std::string, 4> childExpressions;
327   childExpressions.reserve(root.children.size());
328   for (const auto &child : root.children)
329     childExpressions.push_back(getCombinedCondition(*child));
330 
331   // Combine the expressions based on the predicate node kind.
332   if (root.kind == PredCombinerKind::And)
333     return combineBinary(childExpressions, "&&", "true");
334   if (root.kind == PredCombinerKind::Or)
335     return combineBinary(childExpressions, "||", "false");
336   if (root.kind == PredCombinerKind::Not)
337     return combineNot(childExpressions);
338   if (root.kind == PredCombinerKind::Concat) {
339     assert(childExpressions.size() == 1 &&
340            "ConcatPred should only have one child");
341     return root.prefix + childExpressions.front() + root.suffix;
342   }
343 
344   // Substitutions were applied before so just ignore them.
345   if (root.kind == PredCombinerKind::SubstLeaves) {
346     assert(childExpressions.size() == 1 &&
347            "substitution predicate must have one child");
348     return childExpressions[0];
349   }
350 
351   llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
352 }
353 
getConditionImpl() const354 std::string CombinedPred::getConditionImpl() const {
355   llvm::SpecificBumpPtrAllocator<PredNode> allocator;
356   auto *predicateTree = buildPredicateTree(*this, allocator, {});
357   predicateTree =
358       propagateGroundTruth(predicateTree,
359                            /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
360                            /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>());
361 
362   return getCombinedCondition(*predicateTree);
363 }
364 
getPattern() const365 StringRef SubstLeavesPred::getPattern() const {
366   return def->getValueAsString("pattern");
367 }
368 
getReplacement() const369 StringRef SubstLeavesPred::getReplacement() const {
370   return def->getValueAsString("replacement");
371 }
372 
getPrefix() const373 StringRef ConcatPred::getPrefix() const {
374   return def->getValueAsString("prefix");
375 }
376 
getSuffix() const377 StringRef ConcatPred::getSuffix() const {
378   return def->getValueAsString("suffix");
379 }
380