1 //===- Predicate.cpp - Predicate class ------------------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // Wrapper around predicates defined in TableGen.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/TableGen/Predicate.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/ADT/SmallPtrSet.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/TableGen/Error.h"
28 #include "llvm/TableGen/Record.h"
29 
30 using namespace mlir;
31 
32 // Construct a Predicate from a record.
33 tblgen::Pred::Pred(const llvm::Record *record) : def(record) {
34   assert(def->isSubClassOf("Pred") &&
35          "must be a subclass of TableGen 'Pred' class");
36 }
37 
38 // Construct a Predicate from an initializer.
39 tblgen::Pred::Pred(const llvm::Init *init) : def(nullptr) {
40   if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
41     def = defInit->getDef();
42 }
43 
44 std::string tblgen::Pred::getCondition() const {
45   // Static dispatch to subclasses.
46   if (def->isSubClassOf("CombinedPred"))
47     return static_cast<const CombinedPred *>(this)->getConditionImpl();
48   if (def->isSubClassOf("CPred"))
49     return static_cast<const CPred *>(this)->getConditionImpl();
50   llvm_unreachable("Pred::getCondition must be overridden in subclasses");
51 }
52 
53 bool tblgen::Pred::isCombined() const {
54   return def && def->isSubClassOf("CombinedPred");
55 }
56 
57 ArrayRef<llvm::SMLoc> tblgen::Pred::getLoc() const { return def->getLoc(); }
58 
59 tblgen::CPred::CPred(const llvm::Record *record) : Pred(record) {
60   assert(def->isSubClassOf("CPred") &&
61          "must be a subclass of Tablegen 'CPred' class");
62 }
63 
64 tblgen::CPred::CPred(const llvm::Init *init) : Pred(init) {
65   assert((!def || def->isSubClassOf("CPred")) &&
66          "must be a subclass of Tablegen 'CPred' class");
67 }
68 
69 // Get condition of the C Predicate.
70 std::string tblgen::CPred::getConditionImpl() const {
71   assert(!isNull() && "null predicate does not have a condition");
72   return def->getValueAsString("predExpr");
73 }
74 
75 tblgen::CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
76   assert(def->isSubClassOf("CombinedPred") &&
77          "must be a subclass of Tablegen 'CombinedPred' class");
78 }
79 
80 tblgen::CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
81   assert((!def || def->isSubClassOf("CombinedPred")) &&
82          "must be a subclass of Tablegen 'CombinedPred' class");
83 }
84 
85 const llvm::Record *tblgen::CombinedPred::getCombinerDef() const {
86   assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
87   return def->getValueAsDef("kind");
88 }
89 
90 const std::vector<llvm::Record *> tblgen::CombinedPred::getChildren() const {
91   assert(def->getValue("children") &&
92          "CombinedPred must have a value 'children'");
93   return def->getValueAsListOfDefs("children");
94 }
95 
96 namespace {
97 // Kinds of nodes in a logical predicate tree.
98 enum class PredCombinerKind {
99   Leaf,
100   And,
101   Or,
102   Not,
103   SubstLeaves,
104   Concat,
105   // Special kinds that are used in simplification.
106   False,
107   True
108 };
109 
110 // A node in a logical predicate tree.
111 struct PredNode {
112   PredCombinerKind kind;
113   const tblgen::Pred *predicate;
114   SmallVector<PredNode *, 4> children;
115   std::string expr;
116 
117   // Prefix and suffix are used by ConcatPred.
118   std::string prefix;
119   std::string suffix;
120 };
121 } // end anonymous namespace
122 
123 // Get a predicate tree node kind based on the kind used in the predicate
124 // TableGen record.
125 static PredCombinerKind getPredCombinerKind(const tblgen::Pred &pred) {
126   if (!pred.isCombined())
127     return PredCombinerKind::Leaf;
128 
129   const auto &combinedPred = static_cast<const tblgen::CombinedPred &>(pred);
130   return llvm::StringSwitch<PredCombinerKind>(
131              combinedPred.getCombinerDef()->getName())
132       .Case("PredCombinerAnd", PredCombinerKind::And)
133       .Case("PredCombinerOr", PredCombinerKind::Or)
134       .Case("PredCombinerNot", PredCombinerKind::Not)
135       .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
136       .Case("PredCombinerConcat", PredCombinerKind::Concat);
137 }
138 
139 namespace {
140 // Substitution<pattern, replacement>.
141 using Subst = std::pair<StringRef, StringRef>;
142 } // end anonymous namespace
143 
144 // Build the predicate tree starting from the top-level predicate, which may
145 // have children, and perform leaf substitutions inplace.  Note that after
146 // substitution, nodes are still pointing to the original TableGen record.
147 // All nodes are created within "allocator".
148 static PredNode *buildPredicateTree(const tblgen::Pred &root,
149                                     llvm::BumpPtrAllocator &allocator,
150                                     ArrayRef<Subst> substitutions) {
151   auto *rootNode = allocator.Allocate<PredNode>();
152   new (rootNode) PredNode;
153   rootNode->kind = getPredCombinerKind(root);
154   rootNode->predicate = &root;
155   if (!root.isCombined()) {
156     rootNode->expr = root.getCondition();
157     // Apply all parent substitutions from innermost to outermost.
158     for (const auto &subst : llvm::reverse(substitutions)) {
159       auto pos = rootNode->expr.find(subst.first);
160       while (pos != std::string::npos) {
161         rootNode->expr.replace(pos, subst.first.size(), subst.second);
162         // Skip the newly inserted substring, which itself may consider the
163         // pattern to match.
164         pos += subst.second.size();
165         // Find the next possible match position.
166         pos = rootNode->expr.find(subst.first, pos);
167       }
168     }
169     return rootNode;
170   }
171 
172   // If the current combined predicate is a leaf substitution, append it to the
173   // list before continuing.
174   auto allSubstitutions = llvm::to_vector<4>(substitutions);
175   if (rootNode->kind == PredCombinerKind::SubstLeaves) {
176     const auto &substPred = static_cast<const tblgen::SubstLeavesPred &>(root);
177     allSubstitutions.push_back(
178         {substPred.getPattern(), substPred.getReplacement()});
179   }
180   // If the current predicate is a ConcatPred, record the prefix and suffix.
181   else if (rootNode->kind == PredCombinerKind::Concat) {
182     const auto &concatPred = static_cast<const tblgen::ConcatPred &>(root);
183     rootNode->prefix = concatPred.getPrefix();
184     rootNode->suffix = concatPred.getSuffix();
185   }
186 
187   // Build child subtrees.
188   auto combined = static_cast<const tblgen::CombinedPred &>(root);
189   for (const auto *record : combined.getChildren()) {
190     auto childTree =
191         buildPredicateTree(tblgen::Pred(record), allocator, allSubstitutions);
192     rootNode->children.push_back(childTree);
193   }
194   return rootNode;
195 }
196 
197 // Simplify a predicate tree rooted at "node" using the predicates that are
198 // known to be true(false).  For AND(OR) combined predicates, if any of the
199 // children is known to be false(true), the result is also false(true).
200 // Furthermore, for AND(OR) combined predicates, children that are known to be
201 // true(false) don't have to be checked dynamically.
202 static PredNode *propagateGroundTruth(
203     PredNode *node, const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownTruePreds,
204     const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownFalsePreds) {
205   // If the current predicate is known to be true or false, change the kind of
206   // the node and return immediately.
207   if (knownTruePreds.count(node->predicate) != 0) {
208     node->kind = PredCombinerKind::True;
209     node->children.clear();
210     return node;
211   }
212   if (knownFalsePreds.count(node->predicate) != 0) {
213     node->kind = PredCombinerKind::False;
214     node->children.clear();
215     return node;
216   }
217 
218   // If the current node is a substitution, stop recursion now.
219   // The expressions in the leaves below this node were rewritten, but the nodes
220   // still point to the original predicate records.  While the original
221   // predicate may be known to be true or false, it is not necessarily the case
222   // after rewriting.
223   // TODO(zinenko,jpienaar): we can support ground truth for rewritten
224   // predicates by either (a) having our own unique'ing of the predicates
225   // instead of relying on TableGen record pointers or (b) taking ground truth
226   // values optionally prefixed with a list of substitutions to apply, e.g.
227   // "predX is true by itself as well as predSubY leaf substitution had been
228   // applied to it".
229   if (node->kind == PredCombinerKind::SubstLeaves) {
230     return node;
231   }
232 
233   // Otherwise, look at child nodes.
234 
235   // Move child nodes into some local variable so that they can be optimized
236   // separately and re-added if necessary.
237   llvm::SmallVector<PredNode *, 4> children;
238   std::swap(node->children, children);
239 
240   for (auto &child : children) {
241     // First, simplify the child.  This maintains the predicate as it was.
242     auto simplifiedChild =
243         propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
244 
245     // Just add the child if we don't know how to simplify the current node.
246     if (node->kind != PredCombinerKind::And &&
247         node->kind != PredCombinerKind::Or) {
248       node->children.push_back(simplifiedChild);
249       continue;
250     }
251 
252     // Second, based on the type define which known values of child predicates
253     // immediately collapse this predicate to a known value, and which others
254     // may be safely ignored.
255     //   OR(..., True, ...) = True
256     //   OR(..., False, ...) = OR(..., ...)
257     //   AND(..., False, ...) = False
258     //   AND(..., True, ...) = AND(..., ...)
259     auto collapseKind = node->kind == PredCombinerKind::And
260                             ? PredCombinerKind::False
261                             : PredCombinerKind::True;
262     auto eraseKind = node->kind == PredCombinerKind::And
263                          ? PredCombinerKind::True
264                          : PredCombinerKind::False;
265     const auto &collapseList =
266         node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
267     const auto &eraseList =
268         node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
269     if (simplifiedChild->kind == collapseKind ||
270         collapseList.count(simplifiedChild->predicate) != 0) {
271       node->kind = collapseKind;
272       node->children.clear();
273       return node;
274     } else if (simplifiedChild->kind == eraseKind ||
275                eraseList.count(simplifiedChild->predicate) != 0) {
276       continue;
277     }
278     node->children.push_back(simplifiedChild);
279   }
280   return node;
281 }
282 
283 // Combine a list of predicate expressions using a binary combiner.  If a list
284 // is empty, return "init".
285 static std::string combineBinary(ArrayRef<std::string> children,
286                                  std::string combiner, std::string init) {
287   if (children.empty())
288     return init;
289 
290   auto size = children.size();
291   if (size == 1)
292     return children.front();
293 
294   std::string str;
295   llvm::raw_string_ostream os(str);
296   os << '(' << children.front() << ')';
297   for (unsigned i = 1; i < size; ++i) {
298     os << ' ' << combiner << " (" << children[i] << ')';
299   }
300   return os.str();
301 }
302 
303 // Prepend negation to the only condition in the predicate expression list.
304 static std::string combineNot(ArrayRef<std::string> children) {
305   assert(children.size() == 1 && "expected exactly one child predicate of Neg");
306   return (Twine("!(") + children.front() + Twine(')')).str();
307 }
308 
309 // Recursively traverse the predicate tree in depth-first post-order and build
310 // the final expression.
311 static std::string getCombinedCondition(const PredNode &root) {
312   // Immediately return for non-combiner predicates that don't have children.
313   if (root.kind == PredCombinerKind::Leaf)
314     return root.expr;
315   if (root.kind == PredCombinerKind::True)
316     return "true";
317   if (root.kind == PredCombinerKind::False)
318     return "false";
319 
320   // Recurse into children.
321   llvm::SmallVector<std::string, 4> childExpressions;
322   childExpressions.reserve(root.children.size());
323   for (const auto &child : root.children)
324     childExpressions.push_back(getCombinedCondition(*child));
325 
326   // Combine the expressions based on the predicate node kind.
327   if (root.kind == PredCombinerKind::And)
328     return combineBinary(childExpressions, "&&", "true");
329   if (root.kind == PredCombinerKind::Or)
330     return combineBinary(childExpressions, "||", "false");
331   if (root.kind == PredCombinerKind::Not)
332     return combineNot(childExpressions);
333   if (root.kind == PredCombinerKind::Concat) {
334     assert(childExpressions.size() == 1 &&
335            "ConcatPred should only have one child");
336     return root.prefix + childExpressions.front() + root.suffix;
337   }
338 
339   // Substitutions were applied before so just ignore them.
340   if (root.kind == PredCombinerKind::SubstLeaves) {
341     assert(childExpressions.size() == 1 &&
342            "substitution predicate must have one child");
343     return childExpressions[0];
344   }
345 
346   llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
347 }
348 
349 std::string tblgen::CombinedPred::getConditionImpl() const {
350   llvm::BumpPtrAllocator allocator;
351   auto predicateTree = buildPredicateTree(*this, allocator, {});
352   predicateTree = propagateGroundTruth(
353       predicateTree,
354       /*knownTruePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>(),
355       /*knownFalsePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>());
356 
357   return getCombinedCondition(*predicateTree);
358 }
359 
360 StringRef tblgen::SubstLeavesPred::getPattern() const {
361   return def->getValueAsString("pattern");
362 }
363 
364 StringRef tblgen::SubstLeavesPred::getReplacement() const {
365   return def->getValueAsString("replacement");
366 }
367 
368 StringRef tblgen::ConcatPred::getPrefix() const {
369   return def->getValueAsString("prefix");
370 }
371 
372 StringRef tblgen::ConcatPred::getSuffix() const {
373   return def->getValueAsString("suffix");
374 }
375