1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "PredicateTree.h"
10 #include "RootOrdering.h"
11
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/Interfaces/InferTypeOpInterface.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Debug.h"
20 #include <queue>
21
22 #define DEBUG_TYPE "pdl-predicate-tree"
23
24 using namespace mlir;
25 using namespace mlir::pdl_to_pdl_interp;
26
27 //===----------------------------------------------------------------------===//
28 // Predicate List Building
29 //===----------------------------------------------------------------------===//
30
31 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
32 Value val, PredicateBuilder &builder,
33 DenseMap<Value, Position *> &inputs,
34 Position *pos);
35
36 /// Compares the depths of two positions.
comparePosDepth(Position * lhs,Position * rhs)37 static bool comparePosDepth(Position *lhs, Position *rhs) {
38 return lhs->getOperationDepth() < rhs->getOperationDepth();
39 }
40
41 /// Returns the number of non-range elements within `values`.
getNumNonRangeValues(ValueRange values)42 static unsigned getNumNonRangeValues(ValueRange values) {
43 return llvm::count_if(values.getTypes(),
44 [](Type type) { return !type.isa<pdl::RangeType>(); });
45 }
46
getTreePredicates(std::vector<PositionalPredicate> & predList,Value val,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs,AttributePosition * pos)47 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
48 Value val, PredicateBuilder &builder,
49 DenseMap<Value, Position *> &inputs,
50 AttributePosition *pos) {
51 assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
52 pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
53 predList.emplace_back(pos, builder.getIsNotNull());
54
55 // If the attribute has a type or value, add a constraint.
56 if (Value type = attr.type())
57 getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
58 else if (Attribute value = attr.valueAttr())
59 predList.emplace_back(pos, builder.getAttributeConstraint(value));
60 }
61
62 /// Collect all of the predicates for the given operand position.
getOperandTreePredicates(std::vector<PositionalPredicate> & predList,Value val,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs,Position * pos)63 static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
64 Value val, PredicateBuilder &builder,
65 DenseMap<Value, Position *> &inputs,
66 Position *pos) {
67 Type valueType = val.getType();
68 bool isVariadic = valueType.isa<pdl::RangeType>();
69
70 // If this is a typed operand, add a type constraint.
71 TypeSwitch<Operation *>(val.getDefiningOp())
72 .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
73 // Prevent traversal into a null value if the operand has a proper
74 // index.
75 if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
76 cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
77 predList.emplace_back(pos, builder.getIsNotNull());
78
79 if (Value type = op.type())
80 getTreePredicates(predList, type, builder, inputs,
81 builder.getType(pos));
82 })
83 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
84 Optional<unsigned> index = op.index();
85
86 // Prevent traversal into a null value if the result has a proper index.
87 if (index)
88 predList.emplace_back(pos, builder.getIsNotNull());
89
90 // Get the parent operation of this operand.
91 OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
92 predList.emplace_back(parentPos, builder.getIsNotNull());
93
94 // Ensure that the operands match the corresponding results of the
95 // parent operation.
96 Position *resultPos = nullptr;
97 if (std::is_same<pdl::ResultOp, decltype(op)>::value)
98 resultPos = builder.getResult(parentPos, *index);
99 else
100 resultPos = builder.getResultGroup(parentPos, index, isVariadic);
101 predList.emplace_back(resultPos, builder.getEqualTo(pos));
102
103 // Collect the predicates of the parent operation.
104 getTreePredicates(predList, op.parent(), builder, inputs,
105 (Position *)parentPos);
106 });
107 }
108
getTreePredicates(std::vector<PositionalPredicate> & predList,Value val,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs,OperationPosition * pos,Optional<unsigned> ignoreOperand=llvm::None)109 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
110 Value val, PredicateBuilder &builder,
111 DenseMap<Value, Position *> &inputs,
112 OperationPosition *pos,
113 Optional<unsigned> ignoreOperand = llvm::None) {
114 assert(val.getType().isa<pdl::OperationType>() && "expected operation");
115 pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
116 OperationPosition *opPos = cast<OperationPosition>(pos);
117
118 // Ensure getDefiningOp returns a non-null operation.
119 if (!opPos->isRoot())
120 predList.emplace_back(pos, builder.getIsNotNull());
121
122 // Check that this is the correct root operation.
123 if (Optional<StringRef> opName = op.name())
124 predList.emplace_back(pos, builder.getOperationName(*opName));
125
126 // Check that the operation has the proper number of operands. If there are
127 // any variable length operands, we check a minimum instead of an exact count.
128 OperandRange operands = op.operands();
129 unsigned minOperands = getNumNonRangeValues(operands);
130 if (minOperands != operands.size()) {
131 if (minOperands)
132 predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
133 } else {
134 predList.emplace_back(pos, builder.getOperandCount(minOperands));
135 }
136
137 // Check that the operation has the proper number of results. If there are
138 // any variable length results, we check a minimum instead of an exact count.
139 OperandRange types = op.types();
140 unsigned minResults = getNumNonRangeValues(types);
141 if (minResults == types.size())
142 predList.emplace_back(pos, builder.getResultCount(types.size()));
143 else if (minResults)
144 predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
145
146 // Recurse into any attributes, operands, or results.
147 for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
148 getTreePredicates(
149 predList, std::get<1>(it), builder, inputs,
150 builder.getAttribute(opPos,
151 std::get<0>(it).cast<StringAttr>().getValue()));
152 }
153
154 // Process the operands and results of the operation. For all values up to
155 // the first variable length value, we use the concrete operand/result
156 // number. After that, we use the "group" given that we can't know the
157 // concrete indices until runtime. If there is only one variadic operand
158 // group, we treat it as all of the operands/results of the operation.
159 /// Operands.
160 if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
161 // Ignore the operands if we are performing an upward traversal (in that
162 // case, they have already been visited).
163 if (opPos->isRoot() || opPos->isOperandDefiningOp())
164 getTreePredicates(predList, operands.front(), builder, inputs,
165 builder.getAllOperands(opPos));
166 } else {
167 bool foundVariableLength = false;
168 for (const auto &operandIt : llvm::enumerate(operands)) {
169 bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>();
170 foundVariableLength |= isVariadic;
171
172 // Ignore the specified operand, usually because this position was
173 // visited in an upward traversal via an iterative choice.
174 if (ignoreOperand && *ignoreOperand == operandIt.index())
175 continue;
176
177 Position *pos =
178 foundVariableLength
179 ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
180 : builder.getOperand(opPos, operandIt.index());
181 getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
182 }
183 }
184 /// Results.
185 if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) {
186 getTreePredicates(predList, types.front(), builder, inputs,
187 builder.getType(builder.getAllResults(opPos)));
188 } else {
189 bool foundVariableLength = false;
190 for (auto &resultIt : llvm::enumerate(types)) {
191 bool isVariadic = resultIt.value().getType().isa<pdl::RangeType>();
192 foundVariableLength |= isVariadic;
193
194 auto *resultPos =
195 foundVariableLength
196 ? builder.getResultGroup(pos, resultIt.index(), isVariadic)
197 : builder.getResult(pos, resultIt.index());
198 predList.emplace_back(resultPos, builder.getIsNotNull());
199 getTreePredicates(predList, resultIt.value(), builder, inputs,
200 builder.getType(resultPos));
201 }
202 }
203 }
204
getTreePredicates(std::vector<PositionalPredicate> & predList,Value val,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs,TypePosition * pos)205 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
206 Value val, PredicateBuilder &builder,
207 DenseMap<Value, Position *> &inputs,
208 TypePosition *pos) {
209 // Check for a constraint on a constant type.
210 if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
211 if (Attribute type = typeOp.typeAttr())
212 predList.emplace_back(pos, builder.getTypeConstraint(type));
213 } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
214 if (Attribute typeAttr = typeOp.typesAttr())
215 predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
216 }
217 }
218
219 /// Collect the tree predicates anchored at the given value.
getTreePredicates(std::vector<PositionalPredicate> & predList,Value val,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs,Position * pos)220 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
221 Value val, PredicateBuilder &builder,
222 DenseMap<Value, Position *> &inputs,
223 Position *pos) {
224 // Make sure this input value is accessible to the rewrite.
225 auto it = inputs.try_emplace(val, pos);
226 if (!it.second) {
227 // If this is an input value that has been visited in the tree, add a
228 // constraint to ensure that both instances refer to the same value.
229 if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
230 pdl::TypeOp>(val.getDefiningOp())) {
231 auto minMaxPositions =
232 std::minmax(pos, it.first->second, comparePosDepth);
233 predList.emplace_back(minMaxPositions.second,
234 builder.getEqualTo(minMaxPositions.first));
235 }
236 return;
237 }
238
239 TypeSwitch<Position *>(pos)
240 .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
241 getTreePredicates(predList, val, builder, inputs, pos);
242 })
243 .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
244 getOperandTreePredicates(predList, val, builder, inputs, pos);
245 })
246 .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
247 }
248
getAttributePredicates(pdl::AttributeOp op,std::vector<PositionalPredicate> & predList,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs)249 static void getAttributePredicates(pdl::AttributeOp op,
250 std::vector<PositionalPredicate> &predList,
251 PredicateBuilder &builder,
252 DenseMap<Value, Position *> &inputs) {
253 Position *&attrPos = inputs[op];
254 if (attrPos)
255 return;
256 Attribute value = op.valueAttr();
257 assert(value && "expected non-tree `pdl.attribute` to contain a value");
258 attrPos = builder.getAttributeLiteral(value);
259 }
260
getConstraintPredicates(pdl::ApplyNativeConstraintOp op,std::vector<PositionalPredicate> & predList,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs)261 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
262 std::vector<PositionalPredicate> &predList,
263 PredicateBuilder &builder,
264 DenseMap<Value, Position *> &inputs) {
265 OperandRange arguments = op.args();
266
267 std::vector<Position *> allPositions;
268 allPositions.reserve(arguments.size());
269 for (Value arg : arguments)
270 allPositions.push_back(inputs.lookup(arg));
271
272 // Push the constraint to the furthest position.
273 Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
274 comparePosDepth);
275 PredicateBuilder::Predicate pred =
276 builder.getConstraint(op.name(), allPositions);
277 predList.emplace_back(pos, pred);
278 }
279
getResultPredicates(pdl::ResultOp op,std::vector<PositionalPredicate> & predList,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs)280 static void getResultPredicates(pdl::ResultOp op,
281 std::vector<PositionalPredicate> &predList,
282 PredicateBuilder &builder,
283 DenseMap<Value, Position *> &inputs) {
284 Position *&resultPos = inputs[op];
285 if (resultPos)
286 return;
287
288 // Ensure that the result isn't null.
289 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
290 resultPos = builder.getResult(parentPos, op.index());
291 predList.emplace_back(resultPos, builder.getIsNotNull());
292 }
293
getResultPredicates(pdl::ResultsOp op,std::vector<PositionalPredicate> & predList,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs)294 static void getResultPredicates(pdl::ResultsOp op,
295 std::vector<PositionalPredicate> &predList,
296 PredicateBuilder &builder,
297 DenseMap<Value, Position *> &inputs) {
298 Position *&resultPos = inputs[op];
299 if (resultPos)
300 return;
301
302 // Ensure that the result isn't null if the result has an index.
303 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
304 bool isVariadic = op.getType().isa<pdl::RangeType>();
305 Optional<unsigned> index = op.index();
306 resultPos = builder.getResultGroup(parentPos, index, isVariadic);
307 if (index)
308 predList.emplace_back(resultPos, builder.getIsNotNull());
309 }
310
getTypePredicates(Value typeValue,function_ref<Attribute ()> typeAttrFn,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs)311 static void getTypePredicates(Value typeValue,
312 function_ref<Attribute()> typeAttrFn,
313 PredicateBuilder &builder,
314 DenseMap<Value, Position *> &inputs) {
315 Position *&typePos = inputs[typeValue];
316 if (typePos)
317 return;
318 Attribute typeAttr = typeAttrFn();
319 assert(typeAttr &&
320 "expected non-tree `pdl.type`/`pdl.types` to contain a value");
321 typePos = builder.getTypeLiteral(typeAttr);
322 }
323
324 /// Collect all of the predicates that cannot be determined via walking the
325 /// tree.
getNonTreePredicates(pdl::PatternOp pattern,std::vector<PositionalPredicate> & predList,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs)326 static void getNonTreePredicates(pdl::PatternOp pattern,
327 std::vector<PositionalPredicate> &predList,
328 PredicateBuilder &builder,
329 DenseMap<Value, Position *> &inputs) {
330 for (Operation &op : pattern.body().getOps()) {
331 TypeSwitch<Operation *>(&op)
332 .Case([&](pdl::AttributeOp attrOp) {
333 getAttributePredicates(attrOp, predList, builder, inputs);
334 })
335 .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
336 getConstraintPredicates(constraintOp, predList, builder, inputs);
337 })
338 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
339 getResultPredicates(resultOp, predList, builder, inputs);
340 })
341 .Case([&](pdl::TypeOp typeOp) {
342 getTypePredicates(
343 typeOp, [&] { return typeOp.typeAttr(); }, builder, inputs);
344 })
345 .Case([&](pdl::TypesOp typeOp) {
346 getTypePredicates(
347 typeOp, [&] { return typeOp.typesAttr(); }, builder, inputs);
348 });
349 }
350 }
351
352 namespace {
353
354 /// An op accepting a value at an optional index.
355 struct OpIndex {
356 Value parent;
357 Optional<unsigned> index;
358 };
359
360 /// The parent and operand index of each operation for each root, stored
361 /// as a nested map [root][operation].
362 using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;
363
364 } // namespace
365
366 /// Given a pattern, determines the set of roots present in this pattern.
367 /// These are the operations whose results are not consumed by other operations.
detectRoots(pdl::PatternOp pattern)368 static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
369 // First, collect all the operations that are used as operands
370 // to other operations. These are not roots by default.
371 DenseSet<Value> used;
372 for (auto operationOp : pattern.body().getOps<pdl::OperationOp>()) {
373 for (Value operand : operationOp.operands())
374 TypeSwitch<Operation *>(operand.getDefiningOp())
375 .Case<pdl::ResultOp, pdl::ResultsOp>(
376 [&used](auto resultOp) { used.insert(resultOp.parent()); });
377 }
378
379 // Remove the specified root from the use set, so that we can
380 // always select it as a root, even if it is used by other operations.
381 if (Value root = pattern.getRewriter().root())
382 used.erase(root);
383
384 // Finally, collect all the unused operations.
385 SmallVector<Value> roots;
386 for (Value operationOp : pattern.body().getOps<pdl::OperationOp>())
387 if (!used.contains(operationOp))
388 roots.push_back(operationOp);
389
390 return roots;
391 }
392
393 /// Given a list of candidate roots, builds the cost graph for connecting them.
394 /// The graph is formed by traversing the DAG of operations starting from each
395 /// root and marking the depth of each connector value (operand). Then we join
396 /// the candidate roots based on the common connector values, taking the one
397 /// with the minimum depth. Along the way, we compute, for each candidate root,
398 /// a mapping from each operation (in the DAG underneath this root) to its
399 /// parent operation and the corresponding operand index.
buildCostGraph(ArrayRef<Value> roots,RootOrderingGraph & graph,ParentMaps & parentMaps)400 static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
401 ParentMaps &parentMaps) {
402
403 // The entry of a queue. The entry consists of the following items:
404 // * the value in the DAG underneath the root;
405 // * the parent of the value;
406 // * the operand index of the value in its parent;
407 // * the depth of the visited value.
408 struct Entry {
409 Entry(Value value, Value parent, Optional<unsigned> index, unsigned depth)
410 : value(value), parent(parent), index(index), depth(depth) {}
411
412 Value value;
413 Value parent;
414 Optional<unsigned> index;
415 unsigned depth;
416 };
417
418 // A root of a value and its depth (distance from root to the value).
419 struct RootDepth {
420 Value root;
421 unsigned depth = 0;
422 };
423
424 // Map from candidate connector values to their roots and depths. Using a
425 // small vector with 1 entry because most values belong to a single root.
426 llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
427
428 // Perform a breadth-first traversal of the op DAG rooted at each root.
429 for (Value root : roots) {
430 // The queue of visited values. A value may be present multiple times in
431 // the queue, for multiple parents. We only accept the first occurrence,
432 // which is guaranteed to have the lowest depth.
433 std::queue<Entry> toVisit;
434 toVisit.emplace(root, Value(), 0, 0);
435
436 // The map from value to its parent for the current root.
437 DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
438
439 while (!toVisit.empty()) {
440 Entry entry = toVisit.front();
441 toVisit.pop();
442 // Skip if already visited.
443 if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
444 continue;
445
446 // Mark the root and depth of the value.
447 connectorsRootsDepths[entry.value].push_back({root, entry.depth});
448
449 // Traverse the operands of an operation and result ops.
450 // We intentionally do not traverse attributes and types, because those
451 // are expensive to join on.
452 TypeSwitch<Operation *>(entry.value.getDefiningOp())
453 .Case<pdl::OperationOp>([&](auto operationOp) {
454 OperandRange operands = operationOp.operands();
455 // Special case when we pass all the operands in one range.
456 // For those, the index is empty.
457 if (operands.size() == 1 &&
458 operands[0].getType().isa<pdl::RangeType>()) {
459 toVisit.emplace(operands[0], entry.value, llvm::None,
460 entry.depth + 1);
461 return;
462 }
463
464 // Default case: visit all the operands.
465 for (const auto &p : llvm::enumerate(operationOp.operands()))
466 toVisit.emplace(p.value(), entry.value, p.index(),
467 entry.depth + 1);
468 })
469 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
470 toVisit.emplace(resultOp.parent(), entry.value, resultOp.index(),
471 entry.depth);
472 });
473 }
474 }
475
476 // Now build the cost graph.
477 // This is simply a minimum over all depths for the target root.
478 unsigned nextID = 0;
479 for (const auto &connectorRootsDepths : connectorsRootsDepths) {
480 Value value = connectorRootsDepths.first;
481 ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
482 // If there is only one root for this value, this will not trigger
483 // any edges in the cost graph (a perf optimization).
484 if (rootsDepths.size() == 1)
485 continue;
486
487 for (const RootDepth &p : rootsDepths) {
488 for (const RootDepth &q : rootsDepths) {
489 if (&p == &q)
490 continue;
491 // Insert or retrieve the property of edge from p to q.
492 RootOrderingEntry &entry = graph[q.root][p.root];
493 if (!entry.connector /* new edge */ || entry.cost.first > q.depth) {
494 if (!entry.connector)
495 entry.cost.second = nextID++;
496 entry.cost.first = q.depth;
497 entry.connector = value;
498 }
499 }
500 }
501 }
502
503 assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
504 "the pattern contains a candidate root disconnected from the others");
505 }
506
507 /// Returns true if the operand at the given index needs to be queried using an
508 /// operand group, i.e., if it is variadic itself or follows a variadic operand.
useOperandGroup(pdl::OperationOp op,unsigned index)509 static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
510 OperandRange operands = op.operands();
511 assert(index < operands.size() && "operand index out of range");
512 for (unsigned i = 0; i <= index; ++i)
513 if (operands[i].getType().isa<pdl::RangeType>())
514 return true;
515 return false;
516 }
517
518 /// Visit a node during upward traversal.
visitUpward(std::vector<PositionalPredicate> & predList,OpIndex opIndex,PredicateBuilder & builder,DenseMap<Value,Position * > & valueToPosition,Position * & pos,unsigned rootID)519 static void visitUpward(std::vector<PositionalPredicate> &predList,
520 OpIndex opIndex, PredicateBuilder &builder,
521 DenseMap<Value, Position *> &valueToPosition,
522 Position *&pos, unsigned rootID) {
523 Value value = opIndex.parent;
524 TypeSwitch<Operation *>(value.getDefiningOp())
525 .Case<pdl::OperationOp>([&](auto operationOp) {
526 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
527
528 // Get users and iterate over them.
529 Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
530 Position *foreachPos = builder.getForEach(usersPos, rootID);
531 OperationPosition *opPos = builder.getPassthroughOp(foreachPos);
532
533 // Compare the operand(s) of the user against the input value(s).
534 Position *operandPos;
535 if (!opIndex.index) {
536 // We are querying all the operands of the operation.
537 operandPos = builder.getAllOperands(opPos);
538 } else if (useOperandGroup(operationOp, *opIndex.index)) {
539 // We are querying an operand group.
540 Type type = operationOp.operands()[*opIndex.index].getType();
541 bool variadic = type.isa<pdl::RangeType>();
542 operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
543 } else {
544 // We are querying an individual operand.
545 operandPos = builder.getOperand(opPos, *opIndex.index);
546 }
547 predList.emplace_back(operandPos, builder.getEqualTo(pos));
548
549 // Guard against duplicate upward visits. These are not possible,
550 // because if this value was already visited, it would have been
551 // cheaper to start the traversal at this value rather than at the
552 // `connector`, violating the optimality of our spanning tree.
553 bool inserted = valueToPosition.try_emplace(value, opPos).second;
554 (void)inserted;
555 assert(inserted && "duplicate upward visit");
556
557 // Obtain the tree predicates at the current value.
558 getTreePredicates(predList, value, builder, valueToPosition, opPos,
559 opIndex.index);
560
561 // Update the position
562 pos = opPos;
563 })
564 .Case<pdl::ResultOp>([&](auto resultOp) {
565 // Traverse up an individual result.
566 auto *opPos = dyn_cast<OperationPosition>(pos);
567 assert(opPos && "operations and results must be interleaved");
568 pos = builder.getResult(opPos, *opIndex.index);
569
570 // Insert the result position in case we have not visited it yet.
571 valueToPosition.try_emplace(value, pos);
572 })
573 .Case<pdl::ResultsOp>([&](auto resultOp) {
574 // Traverse up a group of results.
575 auto *opPos = dyn_cast<OperationPosition>(pos);
576 assert(opPos && "operations and results must be interleaved");
577 bool isVariadic = value.getType().isa<pdl::RangeType>();
578 if (opIndex.index)
579 pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
580 else
581 pos = builder.getAllResults(opPos);
582
583 // Insert the result position in case we have not visited it yet.
584 valueToPosition.try_emplace(value, pos);
585 });
586 }
587
588 /// Given a pattern operation, build the set of matcher predicates necessary to
589 /// match this pattern.
buildPredicateList(pdl::PatternOp pattern,PredicateBuilder & builder,std::vector<PositionalPredicate> & predList,DenseMap<Value,Position * > & valueToPosition)590 static Value buildPredicateList(pdl::PatternOp pattern,
591 PredicateBuilder &builder,
592 std::vector<PositionalPredicate> &predList,
593 DenseMap<Value, Position *> &valueToPosition) {
594 SmallVector<Value> roots = detectRoots(pattern);
595
596 // Build the root ordering graph and compute the parent maps.
597 RootOrderingGraph graph;
598 ParentMaps parentMaps;
599 buildCostGraph(roots, graph, parentMaps);
600 LLVM_DEBUG({
601 llvm::dbgs() << "Graph:\n";
602 for (auto &target : graph) {
603 llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first
604 << "\n";
605 for (auto &source : target.second) {
606 RootOrderingEntry &entry = source.second;
607 llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first
608 << ":" << entry.cost.second << " via "
609 << entry.connector.getLoc() << "\n";
610 }
611 }
612 });
613
614 // Solve the optimal branching problem for each candidate root, or use the
615 // provided one.
616 Value bestRoot = pattern.getRewriter().root();
617 OptimalBranching::EdgeList bestEdges;
618 if (!bestRoot) {
619 unsigned bestCost = 0;
620 LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
621 for (Value root : roots) {
622 OptimalBranching solver(graph, root);
623 unsigned cost = solver.solve();
624 LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n");
625 if (!bestRoot || bestCost > cost) {
626 bestCost = cost;
627 bestRoot = root;
628 bestEdges = solver.preOrderTraversal(roots);
629 }
630 }
631 } else {
632 OptimalBranching solver(graph, bestRoot);
633 solver.solve();
634 bestEdges = solver.preOrderTraversal(roots);
635 }
636
637 // Print the best solution.
638 LLVM_DEBUG({
639 llvm::dbgs() << "Best tree:\n";
640 for (const std::pair<Value, Value> &edge : bestEdges) {
641 llvm::dbgs() << " * " << edge.first;
642 if (edge.second)
643 llvm::dbgs() << " <- " << edge.second;
644 llvm::dbgs() << "\n";
645 }
646 });
647
648 LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
649 LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n");
650
651 // The best root is the starting point for the traversal. Get the tree
652 // predicates for the DAG rooted at bestRoot.
653 getTreePredicates(predList, bestRoot, builder, valueToPosition,
654 builder.getRoot());
655
656 // Traverse the selected optimal branching. For all edges in order, traverse
657 // up starting from the connector, until the candidate root is reached, and
658 // call getTreePredicates at every node along the way.
659 for (const auto &it : llvm::enumerate(bestEdges)) {
660 Value target = it.value().first;
661 Value source = it.value().second;
662
663 // Check if we already visited the target root. This happens in two cases:
664 // 1) the initial root (bestRoot);
665 // 2) a root that is dominated by (contained in the subtree rooted at) an
666 // already visited root.
667 if (valueToPosition.count(target))
668 continue;
669
670 // Determine the connector.
671 Value connector = graph[target][source].connector;
672 assert(connector && "invalid edge");
673 LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n");
674 DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
675 Position *pos = valueToPosition.lookup(connector);
676 assert(pos && "connector has not been traversed yet");
677
678 // Traverse from the connector upwards towards the target root.
679 for (Value value = connector; value != target;) {
680 OpIndex opIndex = parentMap.lookup(value);
681 assert(opIndex.parent && "missing parent");
682 visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
683 value = opIndex.parent;
684 }
685 }
686
687 getNonTreePredicates(pattern, predList, builder, valueToPosition);
688
689 return bestRoot;
690 }
691
692 //===----------------------------------------------------------------------===//
693 // Pattern Predicate Tree Merging
694 //===----------------------------------------------------------------------===//
695
696 namespace {
697
698 /// This class represents a specific predicate applied to a position, and
699 /// provides hashing and ordering operators. This class allows for computing a
700 /// frequence sum and ordering predicates based on a cost model.
701 struct OrderedPredicate {
OrderedPredicate__anon12da54731511::OrderedPredicate702 OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
703 : position(ip.first), question(ip.second) {}
OrderedPredicate__anon12da54731511::OrderedPredicate704 OrderedPredicate(const PositionalPredicate &ip)
705 : position(ip.position), question(ip.question) {}
706
707 /// The position this predicate is applied to.
708 Position *position;
709
710 /// The question that is applied by this predicate onto the position.
711 Qualifier *question;
712
713 /// The first and second order benefit sums.
714 /// The primary sum is the number of occurrences of this predicate among all
715 /// of the patterns.
716 unsigned primary = 0;
717 /// The secondary sum is a squared summation of the primary sum of all of the
718 /// predicates within each pattern that contains this predicate. This allows
719 /// for favoring predicates that are more commonly shared within a pattern, as
720 /// opposed to those shared across patterns.
721 unsigned secondary = 0;
722
723 /// The tie breaking ID, used to preserve a deterministic (insertion) order
724 /// among all the predicates with the same priority, depth, and position /
725 /// predicate dependency.
726 unsigned id = 0;
727
728 /// A map between a pattern operation and the answer to the predicate question
729 /// within that pattern.
730 DenseMap<Operation *, Qualifier *> patternToAnswer;
731
732 /// Returns true if this predicate is ordered before `rhs`, based on the cost
733 /// model.
operator <__anon12da54731511::OrderedPredicate734 bool operator<(const OrderedPredicate &rhs) const {
735 // Sort by:
736 // * higher first and secondary order sums
737 // * lower depth
738 // * lower position dependency
739 // * lower predicate dependency
740 // * lower tie breaking ID
741 auto *rhsPos = rhs.position;
742 return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
743 rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
744 std::make_tuple(rhs.primary, rhs.secondary,
745 position->getOperationDepth(), position->getKind(),
746 question->getKind(), id);
747 }
748 };
749
750 /// A DenseMapInfo for OrderedPredicate based solely on the position and
751 /// question.
752 struct OrderedPredicateDenseInfo {
753 using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
754
getEmptyKey__anon12da54731511::OrderedPredicateDenseInfo755 static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
getTombstoneKey__anon12da54731511::OrderedPredicateDenseInfo756 static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
isEqual__anon12da54731511::OrderedPredicateDenseInfo757 static bool isEqual(const OrderedPredicate &lhs,
758 const OrderedPredicate &rhs) {
759 return lhs.position == rhs.position && lhs.question == rhs.question;
760 }
getHashValue__anon12da54731511::OrderedPredicateDenseInfo761 static unsigned getHashValue(const OrderedPredicate &p) {
762 return llvm::hash_combine(p.position, p.question);
763 }
764 };
765
766 /// This class wraps a set of ordered predicates that are used within a specific
767 /// pattern operation.
768 struct OrderedPredicateList {
OrderedPredicateList__anon12da54731511::OrderedPredicateList769 OrderedPredicateList(pdl::PatternOp pattern, Value root)
770 : pattern(pattern), root(root) {}
771
772 pdl::PatternOp pattern;
773 Value root;
774 DenseSet<OrderedPredicate *> predicates;
775 };
776 } // namespace
777
778 /// Returns true if the given matcher refers to the same predicate as the given
779 /// ordered predicate. This means that the position and questions of the two
780 /// match.
isSamePredicate(MatcherNode * node,OrderedPredicate * predicate)781 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
782 return node->getPosition() == predicate->position &&
783 node->getQuestion() == predicate->question;
784 }
785
786 /// Get or insert a child matcher for the given parent switch node, given a
787 /// predicate and parent pattern.
getOrCreateChild(SwitchNode * node,OrderedPredicate * predicate,pdl::PatternOp pattern)788 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
789 OrderedPredicate *predicate,
790 pdl::PatternOp pattern) {
791 assert(isSamePredicate(node, predicate) &&
792 "expected matcher to equal the given predicate");
793
794 auto it = predicate->patternToAnswer.find(pattern);
795 assert(it != predicate->patternToAnswer.end() &&
796 "expected pattern to exist in predicate");
797 return node->getChildren().insert({it->second, nullptr}).first->second;
798 }
799
800 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
801 /// order. A pattern will traverse as far as possible using common predicates
802 /// and then either diverge from the CFG or reach the end of a branch and start
803 /// creating new nodes.
propagatePattern(std::unique_ptr<MatcherNode> & node,OrderedPredicateList & list,std::vector<OrderedPredicate * >::iterator current,std::vector<OrderedPredicate * >::iterator end)804 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
805 OrderedPredicateList &list,
806 std::vector<OrderedPredicate *>::iterator current,
807 std::vector<OrderedPredicate *>::iterator end) {
808 if (current == end) {
809 // We've hit the end of a pattern, so create a successful result node.
810 node =
811 std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
812
813 // If the pattern doesn't contain this predicate, ignore it.
814 } else if (list.predicates.find(*current) == list.predicates.end()) {
815 propagatePattern(node, list, std::next(current), end);
816
817 // If the current matcher node is invalid, create a new one for this
818 // position and continue propagation.
819 } else if (!node) {
820 // Create a new node at this position and continue
821 node = std::make_unique<SwitchNode>((*current)->position,
822 (*current)->question);
823 propagatePattern(
824 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
825 list, std::next(current), end);
826
827 // If the matcher has already been created, and it is for this predicate we
828 // continue propagation to the child.
829 } else if (isSamePredicate(node.get(), *current)) {
830 propagatePattern(
831 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
832 list, std::next(current), end);
833
834 // If the matcher doesn't match the current predicate, insert a branch as
835 // the common set of matchers has diverged.
836 } else {
837 propagatePattern(node->getFailureNode(), list, current, end);
838 }
839 }
840
841 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
842 /// `node` is updated in-place if it is a switch.
foldSwitchToBool(std::unique_ptr<MatcherNode> & node)843 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
844 if (!node)
845 return;
846
847 if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
848 SwitchNode::ChildMapT &children = switchNode->getChildren();
849 for (auto &it : children)
850 foldSwitchToBool(it.second);
851
852 // If the node only contains one child, collapse it into a boolean predicate
853 // node.
854 if (children.size() == 1) {
855 auto childIt = children.begin();
856 node = std::make_unique<BoolNode>(
857 node->getPosition(), node->getQuestion(), childIt->first,
858 std::move(childIt->second), std::move(node->getFailureNode()));
859 }
860 } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
861 foldSwitchToBool(boolNode->getSuccessNode());
862 }
863
864 foldSwitchToBool(node->getFailureNode());
865 }
866
867 /// Insert an exit node at the end of the failure path of the `root`.
insertExitNode(std::unique_ptr<MatcherNode> * root)868 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
869 while (*root)
870 root = &(*root)->getFailureNode();
871 *root = std::make_unique<ExitNode>();
872 }
873
874 /// Given a module containing PDL pattern operations, generate a matcher tree
875 /// using the patterns within the given module and return the root matcher node.
876 std::unique_ptr<MatcherNode>
generateMatcherTree(ModuleOp module,PredicateBuilder & builder,DenseMap<Value,Position * > & valueToPosition)877 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
878 DenseMap<Value, Position *> &valueToPosition) {
879 // The set of predicates contained within the pattern operations of the
880 // module.
881 struct PatternPredicates {
882 PatternPredicates(pdl::PatternOp pattern, Value root,
883 std::vector<PositionalPredicate> predicates)
884 : pattern(pattern), root(root), predicates(std::move(predicates)) {}
885
886 /// A pattern.
887 pdl::PatternOp pattern;
888
889 /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
890 Value root;
891
892 /// The extracted predicates for this pattern and root.
893 std::vector<PositionalPredicate> predicates;
894 };
895
896 SmallVector<PatternPredicates, 16> patternsAndPredicates;
897 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
898 std::vector<PositionalPredicate> predicateList;
899 Value root =
900 buildPredicateList(pattern, builder, predicateList, valueToPosition);
901 patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
902 }
903
904 // Associate a pattern result with each unique predicate.
905 DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
906 for (auto &patternAndPredList : patternsAndPredicates) {
907 for (auto &predicate : patternAndPredList.predicates) {
908 auto it = uniqued.insert(predicate);
909 it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
910 predicate.answer);
911 // Mark the insertion order (0-based indexing).
912 if (it.second)
913 it.first->id = uniqued.size() - 1;
914 }
915 }
916
917 // Associate each pattern to a set of its ordered predicates for later lookup.
918 std::vector<OrderedPredicateList> lists;
919 lists.reserve(patternsAndPredicates.size());
920 for (auto &patternAndPredList : patternsAndPredicates) {
921 OrderedPredicateList list(patternAndPredList.pattern,
922 patternAndPredList.root);
923 for (auto &predicate : patternAndPredList.predicates) {
924 OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
925 list.predicates.insert(orderedPredicate);
926
927 // Increment the primary sum for each reference to a particular predicate.
928 ++orderedPredicate->primary;
929 }
930 lists.push_back(std::move(list));
931 }
932
933 // For a particular pattern, get the total primary sum and add it to the
934 // secondary sum of each predicate. Square the primary sums to emphasize
935 // shared predicates within rather than across patterns.
936 for (auto &list : lists) {
937 unsigned total = 0;
938 for (auto *predicate : list.predicates)
939 total += predicate->primary * predicate->primary;
940 for (auto *predicate : list.predicates)
941 predicate->secondary += total;
942 }
943
944 // Sort the set of predicates now that the cost primary and secondary sums
945 // have been computed.
946 std::vector<OrderedPredicate *> ordered;
947 ordered.reserve(uniqued.size());
948 for (auto &ip : uniqued)
949 ordered.push_back(&ip);
950 llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
951 return *lhs < *rhs;
952 });
953
954 // Build the matchers for each of the pattern predicate lists.
955 std::unique_ptr<MatcherNode> root;
956 for (OrderedPredicateList &list : lists)
957 propagatePattern(root, list, ordered.begin(), ordered.end());
958
959 // Collapse the graph and insert the exit node.
960 foldSwitchToBool(root);
961 insertExitNode(&root);
962 return root;
963 }
964
965 //===----------------------------------------------------------------------===//
966 // MatcherNode
967 //===----------------------------------------------------------------------===//
968
MatcherNode(TypeID matcherTypeID,Position * p,Qualifier * q,std::unique_ptr<MatcherNode> failureNode)969 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
970 std::unique_ptr<MatcherNode> failureNode)
971 : position(p), question(q), failureNode(std::move(failureNode)),
972 matcherTypeID(matcherTypeID) {}
973
974 //===----------------------------------------------------------------------===//
975 // BoolNode
976 //===----------------------------------------------------------------------===//
977
BoolNode(Position * position,Qualifier * question,Qualifier * answer,std::unique_ptr<MatcherNode> successNode,std::unique_ptr<MatcherNode> failureNode)978 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
979 std::unique_ptr<MatcherNode> successNode,
980 std::unique_ptr<MatcherNode> failureNode)
981 : MatcherNode(TypeID::get<BoolNode>(), position, question,
982 std::move(failureNode)),
983 answer(answer), successNode(std::move(successNode)) {}
984
985 //===----------------------------------------------------------------------===//
986 // SuccessNode
987 //===----------------------------------------------------------------------===//
988
SuccessNode(pdl::PatternOp pattern,Value root,std::unique_ptr<MatcherNode> failureNode)989 SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
990 std::unique_ptr<MatcherNode> failureNode)
991 : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
992 /*question=*/nullptr, std::move(failureNode)),
993 pattern(pattern), root(root) {}
994
995 //===----------------------------------------------------------------------===//
996 // SwitchNode
997 //===----------------------------------------------------------------------===//
998
SwitchNode(Position * position,Qualifier * question)999 SwitchNode::SwitchNode(Position *position, Qualifier *question)
1000 : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
1001