1 //===- Predicate.cpp - Predicate class ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Wrapper around predicates defined in TableGen.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/TableGen/Predicate.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/SmallPtrSet.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/TableGen/Error.h"
19 #include "llvm/TableGen/Record.h"
20 
21 using namespace mlir;
22 
23 // Construct a Predicate from a record.
24 tblgen::Pred::Pred(const llvm::Record *record) : def(record) {
25   assert(def->isSubClassOf("Pred") &&
26          "must be a subclass of TableGen 'Pred' class");
27 }
28 
29 // Construct a Predicate from an initializer.
30 tblgen::Pred::Pred(const llvm::Init *init) : def(nullptr) {
31   if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
32     def = defInit->getDef();
33 }
34 
35 std::string tblgen::Pred::getCondition() const {
36   // Static dispatch to subclasses.
37   if (def->isSubClassOf("CombinedPred"))
38     return static_cast<const CombinedPred *>(this)->getConditionImpl();
39   if (def->isSubClassOf("CPred"))
40     return static_cast<const CPred *>(this)->getConditionImpl();
41   llvm_unreachable("Pred::getCondition must be overridden in subclasses");
42 }
43 
44 bool tblgen::Pred::isCombined() const {
45   return def && def->isSubClassOf("CombinedPred");
46 }
47 
48 ArrayRef<llvm::SMLoc> tblgen::Pred::getLoc() const { return def->getLoc(); }
49 
50 tblgen::CPred::CPred(const llvm::Record *record) : Pred(record) {
51   assert(def->isSubClassOf("CPred") &&
52          "must be a subclass of Tablegen 'CPred' class");
53 }
54 
55 tblgen::CPred::CPred(const llvm::Init *init) : Pred(init) {
56   assert((!def || def->isSubClassOf("CPred")) &&
57          "must be a subclass of Tablegen 'CPred' class");
58 }
59 
60 // Get condition of the C Predicate.
61 std::string tblgen::CPred::getConditionImpl() const {
62   assert(!isNull() && "null predicate does not have a condition");
63   return std::string(def->getValueAsString("predExpr"));
64 }
65 
66 tblgen::CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
67   assert(def->isSubClassOf("CombinedPred") &&
68          "must be a subclass of Tablegen 'CombinedPred' class");
69 }
70 
71 tblgen::CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
72   assert((!def || def->isSubClassOf("CombinedPred")) &&
73          "must be a subclass of Tablegen 'CombinedPred' class");
74 }
75 
76 const llvm::Record *tblgen::CombinedPred::getCombinerDef() const {
77   assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
78   return def->getValueAsDef("kind");
79 }
80 
81 const std::vector<llvm::Record *> tblgen::CombinedPred::getChildren() const {
82   assert(def->getValue("children") &&
83          "CombinedPred must have a value 'children'");
84   return def->getValueAsListOfDefs("children");
85 }
86 
87 namespace {
88 // Kinds of nodes in a logical predicate tree.
89 enum class PredCombinerKind {
90   Leaf,
91   And,
92   Or,
93   Not,
94   SubstLeaves,
95   Concat,
96   // Special kinds that are used in simplification.
97   False,
98   True
99 };
100 
101 // A node in a logical predicate tree.
102 struct PredNode {
103   PredCombinerKind kind;
104   const tblgen::Pred *predicate;
105   SmallVector<PredNode *, 4> children;
106   std::string expr;
107 
108   // Prefix and suffix are used by ConcatPred.
109   std::string prefix;
110   std::string suffix;
111 };
112 } // end anonymous namespace
113 
114 // Get a predicate tree node kind based on the kind used in the predicate
115 // TableGen record.
116 static PredCombinerKind getPredCombinerKind(const tblgen::Pred &pred) {
117   if (!pred.isCombined())
118     return PredCombinerKind::Leaf;
119 
120   const auto &combinedPred = static_cast<const tblgen::CombinedPred &>(pred);
121   return llvm::StringSwitch<PredCombinerKind>(
122              combinedPred.getCombinerDef()->getName())
123       .Case("PredCombinerAnd", PredCombinerKind::And)
124       .Case("PredCombinerOr", PredCombinerKind::Or)
125       .Case("PredCombinerNot", PredCombinerKind::Not)
126       .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
127       .Case("PredCombinerConcat", PredCombinerKind::Concat);
128 }
129 
namespace {
// A leaf substitution as a <pattern, replacement> pair: occurrences of the
// pattern in leaf condition strings are replaced by the replacement.
using Subst = std::pair<StringRef, StringRef>;
} // end anonymous namespace
134 
135 // Build the predicate tree starting from the top-level predicate, which may
136 // have children, and perform leaf substitutions inplace.  Note that after
137 // substitution, nodes are still pointing to the original TableGen record.
138 // All nodes are created within "allocator".
139 static PredNode *
140 buildPredicateTree(const tblgen::Pred &root,
141                    llvm::SpecificBumpPtrAllocator<PredNode> &allocator,
142                    ArrayRef<Subst> substitutions) {
143   auto *rootNode = allocator.Allocate();
144   new (rootNode) PredNode;
145   rootNode->kind = getPredCombinerKind(root);
146   rootNode->predicate = &root;
147   if (!root.isCombined()) {
148     rootNode->expr = root.getCondition();
149     // Apply all parent substitutions from innermost to outermost.
150     for (const auto &subst : llvm::reverse(substitutions)) {
151       auto pos = rootNode->expr.find(std::string(subst.first));
152       while (pos != std::string::npos) {
153         rootNode->expr.replace(pos, subst.first.size(),
154                                std::string(subst.second));
155         // Skip the newly inserted substring, which itself may consider the
156         // pattern to match.
157         pos += subst.second.size();
158         // Find the next possible match position.
159         pos = rootNode->expr.find(std::string(subst.first), pos);
160       }
161     }
162     return rootNode;
163   }
164 
165   // If the current combined predicate is a leaf substitution, append it to the
166   // list before continuing.
167   auto allSubstitutions = llvm::to_vector<4>(substitutions);
168   if (rootNode->kind == PredCombinerKind::SubstLeaves) {
169     const auto &substPred = static_cast<const tblgen::SubstLeavesPred &>(root);
170     allSubstitutions.push_back(
171         {substPred.getPattern(), substPred.getReplacement()});
172   }
173   // If the current predicate is a ConcatPred, record the prefix and suffix.
174   else if (rootNode->kind == PredCombinerKind::Concat) {
175     const auto &concatPred = static_cast<const tblgen::ConcatPred &>(root);
176     rootNode->prefix = std::string(concatPred.getPrefix());
177     rootNode->suffix = std::string(concatPred.getSuffix());
178   }
179 
180   // Build child subtrees.
181   auto combined = static_cast<const tblgen::CombinedPred &>(root);
182   for (const auto *record : combined.getChildren()) {
183     auto childTree =
184         buildPredicateTree(tblgen::Pred(record), allocator, allSubstitutions);
185     rootNode->children.push_back(childTree);
186   }
187   return rootNode;
188 }
189 
190 // Simplify a predicate tree rooted at "node" using the predicates that are
191 // known to be true(false).  For AND(OR) combined predicates, if any of the
192 // children is known to be false(true), the result is also false(true).
193 // Furthermore, for AND(OR) combined predicates, children that are known to be
194 // true(false) don't have to be checked dynamically.
195 static PredNode *propagateGroundTruth(
196     PredNode *node, const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownTruePreds,
197     const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownFalsePreds) {
198   // If the current predicate is known to be true or false, change the kind of
199   // the node and return immediately.
200   if (knownTruePreds.count(node->predicate) != 0) {
201     node->kind = PredCombinerKind::True;
202     node->children.clear();
203     return node;
204   }
205   if (knownFalsePreds.count(node->predicate) != 0) {
206     node->kind = PredCombinerKind::False;
207     node->children.clear();
208     return node;
209   }
210 
211   // If the current node is a substitution, stop recursion now.
212   // The expressions in the leaves below this node were rewritten, but the nodes
213   // still point to the original predicate records.  While the original
214   // predicate may be known to be true or false, it is not necessarily the case
215   // after rewriting.
216   // TODO(zinenko,jpienaar): we can support ground truth for rewritten
217   // predicates by either (a) having our own unique'ing of the predicates
218   // instead of relying on TableGen record pointers or (b) taking ground truth
219   // values optionally prefixed with a list of substitutions to apply, e.g.
220   // "predX is true by itself as well as predSubY leaf substitution had been
221   // applied to it".
222   if (node->kind == PredCombinerKind::SubstLeaves) {
223     return node;
224   }
225 
226   // Otherwise, look at child nodes.
227 
228   // Move child nodes into some local variable so that they can be optimized
229   // separately and re-added if necessary.
230   llvm::SmallVector<PredNode *, 4> children;
231   std::swap(node->children, children);
232 
233   for (auto &child : children) {
234     // First, simplify the child.  This maintains the predicate as it was.
235     auto simplifiedChild =
236         propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
237 
238     // Just add the child if we don't know how to simplify the current node.
239     if (node->kind != PredCombinerKind::And &&
240         node->kind != PredCombinerKind::Or) {
241       node->children.push_back(simplifiedChild);
242       continue;
243     }
244 
245     // Second, based on the type define which known values of child predicates
246     // immediately collapse this predicate to a known value, and which others
247     // may be safely ignored.
248     //   OR(..., True, ...) = True
249     //   OR(..., False, ...) = OR(..., ...)
250     //   AND(..., False, ...) = False
251     //   AND(..., True, ...) = AND(..., ...)
252     auto collapseKind = node->kind == PredCombinerKind::And
253                             ? PredCombinerKind::False
254                             : PredCombinerKind::True;
255     auto eraseKind = node->kind == PredCombinerKind::And
256                          ? PredCombinerKind::True
257                          : PredCombinerKind::False;
258     const auto &collapseList =
259         node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
260     const auto &eraseList =
261         node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
262     if (simplifiedChild->kind == collapseKind ||
263         collapseList.count(simplifiedChild->predicate) != 0) {
264       node->kind = collapseKind;
265       node->children.clear();
266       return node;
267     } else if (simplifiedChild->kind == eraseKind ||
268                eraseList.count(simplifiedChild->predicate) != 0) {
269       continue;
270     }
271     node->children.push_back(simplifiedChild);
272   }
273   return node;
274 }
275 
276 // Combine a list of predicate expressions using a binary combiner.  If a list
277 // is empty, return "init".
278 static std::string combineBinary(ArrayRef<std::string> children,
279                                  std::string combiner, std::string init) {
280   if (children.empty())
281     return init;
282 
283   auto size = children.size();
284   if (size == 1)
285     return children.front();
286 
287   std::string str;
288   llvm::raw_string_ostream os(str);
289   os << '(' << children.front() << ')';
290   for (unsigned i = 1; i < size; ++i) {
291     os << ' ' << combiner << " (" << children[i] << ')';
292   }
293   return os.str();
294 }
295 
296 // Prepend negation to the only condition in the predicate expression list.
297 static std::string combineNot(ArrayRef<std::string> children) {
298   assert(children.size() == 1 && "expected exactly one child predicate of Neg");
299   return (Twine("!(") + children.front() + Twine(')')).str();
300 }
301 
302 // Recursively traverse the predicate tree in depth-first post-order and build
303 // the final expression.
304 static std::string getCombinedCondition(const PredNode &root) {
305   // Immediately return for non-combiner predicates that don't have children.
306   if (root.kind == PredCombinerKind::Leaf)
307     return root.expr;
308   if (root.kind == PredCombinerKind::True)
309     return "true";
310   if (root.kind == PredCombinerKind::False)
311     return "false";
312 
313   // Recurse into children.
314   llvm::SmallVector<std::string, 4> childExpressions;
315   childExpressions.reserve(root.children.size());
316   for (const auto &child : root.children)
317     childExpressions.push_back(getCombinedCondition(*child));
318 
319   // Combine the expressions based on the predicate node kind.
320   if (root.kind == PredCombinerKind::And)
321     return combineBinary(childExpressions, "&&", "true");
322   if (root.kind == PredCombinerKind::Or)
323     return combineBinary(childExpressions, "||", "false");
324   if (root.kind == PredCombinerKind::Not)
325     return combineNot(childExpressions);
326   if (root.kind == PredCombinerKind::Concat) {
327     assert(childExpressions.size() == 1 &&
328            "ConcatPred should only have one child");
329     return root.prefix + childExpressions.front() + root.suffix;
330   }
331 
332   // Substitutions were applied before so just ignore them.
333   if (root.kind == PredCombinerKind::SubstLeaves) {
334     assert(childExpressions.size() == 1 &&
335            "substitution predicate must have one child");
336     return childExpressions[0];
337   }
338 
339   llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
340 }
341 
342 std::string tblgen::CombinedPred::getConditionImpl() const {
343   llvm::SpecificBumpPtrAllocator<PredNode> allocator;
344   auto predicateTree = buildPredicateTree(*this, allocator, {});
345   predicateTree = propagateGroundTruth(
346       predicateTree,
347       /*knownTruePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>(),
348       /*knownFalsePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>());
349 
350   return getCombinedCondition(*predicateTree);
351 }
352 
353 StringRef tblgen::SubstLeavesPred::getPattern() const {
354   return def->getValueAsString("pattern");
355 }
356 
357 StringRef tblgen::SubstLeavesPred::getReplacement() const {
358   return def->getValueAsString("replacement");
359 }
360 
361 StringRef tblgen::ConcatPred::getPrefix() const {
362   return def->getValueAsString("prefix");
363 }
364 
365 StringRef tblgen::ConcatPred::getSuffix() const {
366   return def->getValueAsString("suffix");
367 }
368