1 //===- Predicate.cpp - Predicate class ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Wrapper around predicates defined in TableGen.
10 //
11 //===----------------------------------------------------------------------===//
12
#include "mlir/TableGen/Predicate.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include <optional>
21
22 using namespace mlir;
23 using namespace tblgen;
24
25 // Construct a Predicate from a record.
Pred(const llvm::Record * record)26 Pred::Pred(const llvm::Record *record) : def(record) {
27 assert(def->isSubClassOf("Pred") &&
28 "must be a subclass of TableGen 'Pred' class");
29 }
30
31 // Construct a Predicate from an initializer.
Pred(const llvm::Init * init)32 Pred::Pred(const llvm::Init *init) {
33 if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
34 def = defInit->getDef();
35 }
36
getCondition() const37 std::string Pred::getCondition() const {
38 // Static dispatch to subclasses.
39 if (def->isSubClassOf("CombinedPred"))
40 return static_cast<const CombinedPred *>(this)->getConditionImpl();
41 if (def->isSubClassOf("CPred"))
42 return static_cast<const CPred *>(this)->getConditionImpl();
43 llvm_unreachable("Pred::getCondition must be overridden in subclasses");
44 }
45
isCombined() const46 bool Pred::isCombined() const {
47 return def && def->isSubClassOf("CombinedPred");
48 }
49
getLoc() const50 ArrayRef<SMLoc> Pred::getLoc() const { return def->getLoc(); }
51
CPred(const llvm::Record * record)52 CPred::CPred(const llvm::Record *record) : Pred(record) {
53 assert(def->isSubClassOf("CPred") &&
54 "must be a subclass of Tablegen 'CPred' class");
55 }
56
CPred(const llvm::Init * init)57 CPred::CPred(const llvm::Init *init) : Pred(init) {
58 assert((!def || def->isSubClassOf("CPred")) &&
59 "must be a subclass of Tablegen 'CPred' class");
60 }
61
62 // Get condition of the C Predicate.
getConditionImpl() const63 std::string CPred::getConditionImpl() const {
64 assert(!isNull() && "null predicate does not have a condition");
65 return std::string(def->getValueAsString("predExpr"));
66 }
67
CombinedPred(const llvm::Record * record)68 CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
69 assert(def->isSubClassOf("CombinedPred") &&
70 "must be a subclass of Tablegen 'CombinedPred' class");
71 }
72
CombinedPred(const llvm::Init * init)73 CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
74 assert((!def || def->isSubClassOf("CombinedPred")) &&
75 "must be a subclass of Tablegen 'CombinedPred' class");
76 }
77
getCombinerDef() const78 const llvm::Record *CombinedPred::getCombinerDef() const {
79 assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
80 return def->getValueAsDef("kind");
81 }
82
getChildren() const83 std::vector<llvm::Record *> CombinedPred::getChildren() const {
84 assert(def->getValue("children") &&
85 "CombinedPred must have a value 'children'");
86 return def->getValueAsListOfDefs("children");
87 }
88
89 namespace {
90 // Kinds of nodes in a logical predicate tree.
91 enum class PredCombinerKind {
92 Leaf,
93 And,
94 Or,
95 Not,
96 SubstLeaves,
97 Concat,
98 // Special kinds that are used in simplification.
99 False,
100 True
101 };
102
103 // A node in a logical predicate tree.
104 struct PredNode {
105 PredCombinerKind kind;
106 const Pred *predicate;
107 SmallVector<PredNode *, 4> children;
108 std::string expr;
109
110 // Prefix and suffix are used by ConcatPred.
111 std::string prefix;
112 std::string suffix;
113 };
114 } // namespace
115
116 // Get a predicate tree node kind based on the kind used in the predicate
117 // TableGen record.
getPredCombinerKind(const Pred & pred)118 static PredCombinerKind getPredCombinerKind(const Pred &pred) {
119 if (!pred.isCombined())
120 return PredCombinerKind::Leaf;
121
122 const auto &combinedPred = static_cast<const CombinedPred &>(pred);
123 return StringSwitch<PredCombinerKind>(
124 combinedPred.getCombinerDef()->getName())
125 .Case("PredCombinerAnd", PredCombinerKind::And)
126 .Case("PredCombinerOr", PredCombinerKind::Or)
127 .Case("PredCombinerNot", PredCombinerKind::Not)
128 .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
129 .Case("PredCombinerConcat", PredCombinerKind::Concat);
130 }
131
132 namespace {
133 // Substitution<pattern, replacement>.
134 using Subst = std::pair<StringRef, StringRef>;
135 } // namespace
136
137 /// Perform the given substitutions on 'str' in-place.
performSubstitutions(std::string & str,ArrayRef<Subst> substitutions)138 static void performSubstitutions(std::string &str,
139 ArrayRef<Subst> substitutions) {
140 // Apply all parent substitutions from innermost to outermost.
141 for (const auto &subst : llvm::reverse(substitutions)) {
142 auto pos = str.find(std::string(subst.first));
143 while (pos != std::string::npos) {
144 str.replace(pos, subst.first.size(), std::string(subst.second));
145 // Skip the newly inserted substring, which itself may consider the
146 // pattern to match.
147 pos += subst.second.size();
148 // Find the next possible match position.
149 pos = str.find(std::string(subst.first), pos);
150 }
151 }
152 }
153
154 // Build the predicate tree starting from the top-level predicate, which may
155 // have children, and perform leaf substitutions inplace. Note that after
156 // substitution, nodes are still pointing to the original TableGen record.
157 // All nodes are created within "allocator".
158 static PredNode *
buildPredicateTree(const Pred & root,llvm::SpecificBumpPtrAllocator<PredNode> & allocator,ArrayRef<Subst> substitutions)159 buildPredicateTree(const Pred &root,
160 llvm::SpecificBumpPtrAllocator<PredNode> &allocator,
161 ArrayRef<Subst> substitutions) {
162 auto *rootNode = allocator.Allocate();
163 new (rootNode) PredNode;
164 rootNode->kind = getPredCombinerKind(root);
165 rootNode->predicate = &root;
166 if (!root.isCombined()) {
167 rootNode->expr = root.getCondition();
168 performSubstitutions(rootNode->expr, substitutions);
169 return rootNode;
170 }
171
172 // If the current combined predicate is a leaf substitution, append it to the
173 // list before continuing.
174 auto allSubstitutions = llvm::to_vector<4>(substitutions);
175 if (rootNode->kind == PredCombinerKind::SubstLeaves) {
176 const auto &substPred = static_cast<const SubstLeavesPred &>(root);
177 allSubstitutions.push_back(
178 {substPred.getPattern(), substPred.getReplacement()});
179
180 // If the current predicate is a ConcatPred, record the prefix and suffix.
181 } else if (rootNode->kind == PredCombinerKind::Concat) {
182 const auto &concatPred = static_cast<const ConcatPred &>(root);
183 rootNode->prefix = std::string(concatPred.getPrefix());
184 performSubstitutions(rootNode->prefix, substitutions);
185 rootNode->suffix = std::string(concatPred.getSuffix());
186 performSubstitutions(rootNode->suffix, substitutions);
187 }
188
189 // Build child subtrees.
190 auto combined = static_cast<const CombinedPred &>(root);
191 for (const auto *record : combined.getChildren()) {
192 auto *childTree =
193 buildPredicateTree(Pred(record), allocator, allSubstitutions);
194 rootNode->children.push_back(childTree);
195 }
196 return rootNode;
197 }
198
199 // Simplify a predicate tree rooted at "node" using the predicates that are
200 // known to be true(false). For AND(OR) combined predicates, if any of the
201 // children is known to be false(true), the result is also false(true).
202 // Furthermore, for AND(OR) combined predicates, children that are known to be
203 // true(false) don't have to be checked dynamically.
204 static PredNode *
propagateGroundTruth(PredNode * node,const llvm::SmallPtrSetImpl<Pred * > & knownTruePreds,const llvm::SmallPtrSetImpl<Pred * > & knownFalsePreds)205 propagateGroundTruth(PredNode *node,
206 const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
207 const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
208 // If the current predicate is known to be true or false, change the kind of
209 // the node and return immediately.
210 if (knownTruePreds.count(node->predicate) != 0) {
211 node->kind = PredCombinerKind::True;
212 node->children.clear();
213 return node;
214 }
215 if (knownFalsePreds.count(node->predicate) != 0) {
216 node->kind = PredCombinerKind::False;
217 node->children.clear();
218 return node;
219 }
220
221 // If the current node is a substitution, stop recursion now.
222 // The expressions in the leaves below this node were rewritten, but the nodes
223 // still point to the original predicate records. While the original
224 // predicate may be known to be true or false, it is not necessarily the case
225 // after rewriting.
226 // TODO: we can support ground truth for rewritten
227 // predicates by either (a) having our own unique'ing of the predicates
228 // instead of relying on TableGen record pointers or (b) taking ground truth
229 // values optionally prefixed with a list of substitutions to apply, e.g.
230 // "predX is true by itself as well as predSubY leaf substitution had been
231 // applied to it".
232 if (node->kind == PredCombinerKind::SubstLeaves) {
233 return node;
234 }
235
236 // Otherwise, look at child nodes.
237
238 // Move child nodes into some local variable so that they can be optimized
239 // separately and re-added if necessary.
240 llvm::SmallVector<PredNode *, 4> children;
241 std::swap(node->children, children);
242
243 for (auto &child : children) {
244 // First, simplify the child. This maintains the predicate as it was.
245 auto *simplifiedChild =
246 propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
247
248 // Just add the child if we don't know how to simplify the current node.
249 if (node->kind != PredCombinerKind::And &&
250 node->kind != PredCombinerKind::Or) {
251 node->children.push_back(simplifiedChild);
252 continue;
253 }
254
255 // Second, based on the type define which known values of child predicates
256 // immediately collapse this predicate to a known value, and which others
257 // may be safely ignored.
258 // OR(..., True, ...) = True
259 // OR(..., False, ...) = OR(..., ...)
260 // AND(..., False, ...) = False
261 // AND(..., True, ...) = AND(..., ...)
262 auto collapseKind = node->kind == PredCombinerKind::And
263 ? PredCombinerKind::False
264 : PredCombinerKind::True;
265 auto eraseKind = node->kind == PredCombinerKind::And
266 ? PredCombinerKind::True
267 : PredCombinerKind::False;
268 const auto &collapseList =
269 node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
270 const auto &eraseList =
271 node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
272 if (simplifiedChild->kind == collapseKind ||
273 collapseList.count(simplifiedChild->predicate) != 0) {
274 node->kind = collapseKind;
275 node->children.clear();
276 return node;
277 }
278 if (simplifiedChild->kind == eraseKind ||
279 eraseList.count(simplifiedChild->predicate) != 0) {
280 continue;
281 }
282 node->children.push_back(simplifiedChild);
283 }
284 return node;
285 }
286
287 // Combine a list of predicate expressions using a binary combiner. If a list
288 // is empty, return "init".
combineBinary(ArrayRef<std::string> children,const std::string & combiner,std::string init)289 static std::string combineBinary(ArrayRef<std::string> children,
290 const std::string &combiner,
291 std::string init) {
292 if (children.empty())
293 return init;
294
295 auto size = children.size();
296 if (size == 1)
297 return children.front();
298
299 std::string str;
300 llvm::raw_string_ostream os(str);
301 os << '(' << children.front() << ')';
302 for (unsigned i = 1; i < size; ++i) {
303 os << ' ' << combiner << " (" << children[i] << ')';
304 }
305 return os.str();
306 }
307
308 // Prepend negation to the only condition in the predicate expression list.
combineNot(ArrayRef<std::string> children)309 static std::string combineNot(ArrayRef<std::string> children) {
310 assert(children.size() == 1 && "expected exactly one child predicate of Neg");
311 return (Twine("!(") + children.front() + Twine(')')).str();
312 }
313
314 // Recursively traverse the predicate tree in depth-first post-order and build
315 // the final expression.
getCombinedCondition(const PredNode & root)316 static std::string getCombinedCondition(const PredNode &root) {
317 // Immediately return for non-combiner predicates that don't have children.
318 if (root.kind == PredCombinerKind::Leaf)
319 return root.expr;
320 if (root.kind == PredCombinerKind::True)
321 return "true";
322 if (root.kind == PredCombinerKind::False)
323 return "false";
324
325 // Recurse into children.
326 llvm::SmallVector<std::string, 4> childExpressions;
327 childExpressions.reserve(root.children.size());
328 for (const auto &child : root.children)
329 childExpressions.push_back(getCombinedCondition(*child));
330
331 // Combine the expressions based on the predicate node kind.
332 if (root.kind == PredCombinerKind::And)
333 return combineBinary(childExpressions, "&&", "true");
334 if (root.kind == PredCombinerKind::Or)
335 return combineBinary(childExpressions, "||", "false");
336 if (root.kind == PredCombinerKind::Not)
337 return combineNot(childExpressions);
338 if (root.kind == PredCombinerKind::Concat) {
339 assert(childExpressions.size() == 1 &&
340 "ConcatPred should only have one child");
341 return root.prefix + childExpressions.front() + root.suffix;
342 }
343
344 // Substitutions were applied before so just ignore them.
345 if (root.kind == PredCombinerKind::SubstLeaves) {
346 assert(childExpressions.size() == 1 &&
347 "substitution predicate must have one child");
348 return childExpressions[0];
349 }
350
351 llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
352 }
353
getConditionImpl() const354 std::string CombinedPred::getConditionImpl() const {
355 llvm::SpecificBumpPtrAllocator<PredNode> allocator;
356 auto *predicateTree = buildPredicateTree(*this, allocator, {});
357 predicateTree =
358 propagateGroundTruth(predicateTree,
359 /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
360 /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>());
361
362 return getCombinedCondition(*predicateTree);
363 }
364
getPattern() const365 StringRef SubstLeavesPred::getPattern() const {
366 return def->getValueAsString("pattern");
367 }
368
getReplacement() const369 StringRef SubstLeavesPred::getReplacement() const {
370 return def->getValueAsString("replacement");
371 }
372
getPrefix() const373 StringRef ConcatPred::getPrefix() const {
374 return def->getValueAsString("prefix");
375 }
376
getSuffix() const377 StringRef ConcatPred::getSuffix() const {
378 return def->getValueAsString("suffix");
379 }
380